From: Mike Bayer Date: Sat, 5 Sep 2020 23:45:04 +0000 (-0400) Subject: Don't rely on string col name in adapt_to_context X-Git-Tag: rel_1_4_0b1~126 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b0e9083eb2a786670a1a129d7968d768d1c4ab42;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Don't rely on string col name in adapt_to_context Fixed an issue where, even though the method claims to match up columns positionally, it failed to do so because it looked up records in "keymap" by string column name. Adds a new member to the _keymap records, MD_RESULT_MAP_INDEX, so that we can efficiently link from the generated keymap back to the compiled._result_columns structure without any ambiguity. Fixes: #5559 Change-Id: Ie2fa9165c16625ef860ffac1190e00575e96761f --- diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 1b48509b4c..43afa3628a 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -34,11 +34,12 @@ _UNPICKLED = util.symbol("unpickled") # metadata entry tuple indexes. # using raw tuple is faster than namedtuple. 
MD_INDEX = 0 # integer index in cursor.description -MD_OBJECTS = 1 # other string keys and ColumnElement obj that can match -MD_LOOKUP_KEY = 2 # string key we usually expect for key-based lookup -MD_RENDERED_NAME = 3 # name that is usually in cursor.description -MD_PROCESSOR = 4 # callable to process a result value into a row -MD_UNTRANSLATED = 5 # raw name from cursor.description +MD_RESULT_MAP_INDEX = 1 # integer index in compiled._result_columns +MD_OBJECTS = 2 # other string keys and ColumnElement obj that can match +MD_LOOKUP_KEY = 3 # string key we usually expect for key-based lookup +MD_RENDERED_NAME = 4 # name that is usually in cursor.description +MD_PROCESSOR = 5 # callable to process a result value into a row +MD_UNTRANSLATED = 6 # raw name from cursor.description class CursorResultMetaData(ResultMetaData): @@ -49,6 +50,7 @@ class CursorResultMetaData(ResultMetaData): "case_sensitive", "_processors", "_keys", + "_keymap_by_result_column_idx", "_tuplefilter", "_translated_indexes", "_safe_for_cache" @@ -112,6 +114,7 @@ class CursorResultMetaData(ResultMetaData): as matched to those of the cached statement. 
""" + if not context.compiled._result_columns: return self @@ -127,13 +130,19 @@ class CursorResultMetaData(ResultMetaData): md._keymap = dict(self._keymap) - # match up new columns positionally to the result columns - for existing, new in zip( - context.compiled._result_columns, - invoked_statement._exported_columns_iterator(), + keymap_by_position = self._keymap_by_result_column_idx + + for idx, new in enumerate( + invoked_statement._exported_columns_iterator() ): - if existing[RM_NAME] in md._keymap: - md._keymap[new] = md._keymap[existing[RM_NAME]] + try: + rec = keymap_by_position[idx] + except KeyError: + # this can happen when there are bogus column entries + # in a TextualSelect + pass + else: + md._keymap[new] = rec md.case_sensitive = self.case_sensitive md._processors = self._processors @@ -141,6 +150,8 @@ class CursorResultMetaData(ResultMetaData): md._tuplefilter = None md._translated_indexes = None md._keys = self._keys + md._keymap_by_result_column_idx = self._keymap_by_result_column_idx + md._safe_for_cache = self._safe_for_cache return md def __init__(self, parent, cursor_description): @@ -186,6 +197,12 @@ class CursorResultMetaData(ResultMetaData): metadata_entry[MD_PROCESSOR] for metadata_entry in raw ] + if context.compiled: + self._keymap_by_result_column_idx = { + metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry + for metadata_entry in raw + } + # keymap by primary string... by_key = dict( [ @@ -237,7 +254,7 @@ class CursorResultMetaData(ResultMetaData): # then for the dupe keys, put the "ambiguous column" # record into by_key. 
- by_key.update({key: (None, (), key) for key in dupes}) + by_key.update({key: (None, None, (), key) for key in dupes}) else: # no dupes - copy secondary elements from compiled @@ -350,6 +367,7 @@ class CursorResultMetaData(ResultMetaData): self._safe_for_cache = True return [ ( + idx, idx, rmap_entry[RM_OBJECTS], rmap_entry[RM_NAME].lower() @@ -399,6 +417,7 @@ class CursorResultMetaData(ResultMetaData): return [ ( idx, + ridx, obj, cursor_colname, cursor_colname, @@ -409,6 +428,7 @@ class CursorResultMetaData(ResultMetaData): ) for ( idx, + ridx, cursor_colname, mapped_type, coltype, @@ -480,6 +500,7 @@ class CursorResultMetaData(ResultMetaData): if idx < num_ctx_cols: ctx_rec = result_columns[idx] obj = ctx_rec[RM_OBJECTS] + ridx = idx mapped_type = ctx_rec[RM_TYPE] if obj[0] in seen: raise exc.InvalidRequestError( @@ -490,7 +511,8 @@ class CursorResultMetaData(ResultMetaData): else: mapped_type = sqltypes.NULLTYPE obj = None - yield idx, colname, mapped_type, coltype, obj, untranslated + ridx = None + yield idx, ridx, colname, mapped_type, coltype, obj, untranslated def _merge_cols_by_name( self, @@ -504,7 +526,6 @@ class CursorResultMetaData(ResultMetaData): match_map = self._create_description_match_map( result_columns, case_sensitive, loose_column_name_matching ) - for ( idx, colname, @@ -516,10 +537,20 @@ class CursorResultMetaData(ResultMetaData): except KeyError: mapped_type = sqltypes.NULLTYPE obj = None + result_columns_idx = None else: obj = ctx_rec[1] mapped_type = ctx_rec[2] - yield idx, colname, mapped_type, coltype, obj, untranslated + result_columns_idx = ctx_rec[3] + yield ( + idx, + result_columns_idx, + colname, + mapped_type, + coltype, + obj, + untranslated, + ) @classmethod def _create_description_match_map( @@ -534,7 +565,7 @@ class CursorResultMetaData(ResultMetaData): """ d = {} - for elem in result_columns: + for ridx, elem in enumerate(result_columns): key = elem[RM_RENDERED_NAME] if not case_sensitive: @@ -544,10 +575,10 @@ class 
CursorResultMetaData(ResultMetaData): # to the existing record. if there is a duplicate column # name in the cursor description, this will allow all of those # objects to raise an ambiguous column error - e_name, e_obj, e_type = d[key] - d[key] = e_name, e_obj + elem[RM_OBJECTS], e_type + e_name, e_obj, e_type, e_ridx = d[key] + d[key] = e_name, e_obj + elem[RM_OBJECTS], e_type, ridx else: - d[key] = (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE]) + d[key] = (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx) if loose_column_name_matching: # when using a textual statement with an unordered set @@ -557,7 +588,8 @@ class CursorResultMetaData(ResultMetaData): # duplicate keys that are ambiguous will be fixed later. for r_key in elem[RM_OBJECTS]: d.setdefault( - r_key, (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE]) + r_key, + (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx), ) return d @@ -569,7 +601,15 @@ class CursorResultMetaData(ResultMetaData): untranslated, coltype, ) in self._colnames_from_description(context, cursor_description): - yield idx, colname, sqltypes.NULLTYPE, coltype, None, untranslated + yield ( + idx, + None, + colname, + sqltypes.NULLTYPE, + coltype, + None, + untranslated, + ) def _key_fallback(self, key, err, raiseerr=True): if raiseerr: @@ -637,7 +677,7 @@ class CursorResultMetaData(ResultMetaData): def __getstate__(self): return { "_keymap": { - key: (rec[MD_INDEX], _UNPICKLED, key) + key: (rec[MD_INDEX], rec[MD_RESULT_MAP_INDEX], _UNPICKLED, key) for key, rec in self._keymap.items() if isinstance(key, util.string_types + util.int_types) }, @@ -651,6 +691,9 @@ class CursorResultMetaData(ResultMetaData): self._processors = [None for _ in range(len(state["_keys"]))] self._keymap = state["_keymap"] + self._keymap_by_result_column_idx = { + rec[MD_RESULT_MAP_INDEX]: rec for rec in self._keymap.values() + } self._keys = state["_keys"] self.case_sensitive = state["case_sensitive"] @@ -1220,6 +1263,7 @@ class BaseCursorResult(object): 
self._metadata = _NO_RESULT_METADATA def _init_metadata(self, context, cursor_description): + if context.compiled: if context.compiled._cached_metadata: metadata = self.context.compiled._cached_metadata diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index a4a8cd99f7..8bfaded8fe 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -545,7 +545,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): stmt = select(type_coerce(column("x"), MyType).label("foo")).limit(1) dialect = oracle.dialect() compiled = stmt.compile(dialect=dialect) - assert isinstance(compiled._create_result_map()["foo"][-1], MyType) + assert isinstance(compiled._create_result_map()["foo"][-2], MyType) def test_use_binds_for_limits_disabled_one(self): t = table("sometable", column("col1"), column("col2")) @@ -1061,8 +1061,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): eq_( compiled._create_result_map(), { - "c3": ("c3", (t1.c.c3, "c3", "c3"), t1.c.c3.type), - "lower": ("lower", (fn, "lower", None), fn.type), + "c3": ("c3", (t1.c.c3, "c3", "c3"), t1.c.c3.type, 1), + "lower": ("lower", (fn, "lower", None), fn.type, 0), }, ) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 1084d30cb0..b43d09045d 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -4716,6 +4716,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL): "here_yetagain_anotherid", ), t1.c.anotherid.type, + 0, ) }, ) @@ -5200,8 +5201,8 @@ class ResultMapTest(fixtures.TestBase): eq_( comp._create_result_map(), { - "a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type), - "b": ("b", (t.c.b, "b", "b", "t_b"), t.c.b.type), + "a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type, 0), + "b": ("b", (t.c.b, "b", "b", "t_b"), t.c.b.type, 1), }, ) @@ -5212,7 +5213,7 @@ class ResultMapTest(fixtures.TestBase): comp = stmt.compile() eq_( comp._create_result_map(), - {"a": ("a", (t.c.a, "a", "a", "t_a"), 
t.c.a.type)}, + {"a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type, 0)}, ) def test_compound_only_top_populates(self): @@ -5221,7 +5222,7 @@ class ResultMapTest(fixtures.TestBase): comp = stmt.compile() eq_( comp._create_result_map(), - {"a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type)}, + {"a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type, 0)}, ) def test_label_plus_element(self): @@ -5234,12 +5235,13 @@ class ResultMapTest(fixtures.TestBase): eq_( comp._create_result_map(), { - "a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type), - "bar": ("bar", (l1, "bar"), l1.type), + "a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type, 0), + "bar": ("bar", (l1, "bar"), l1.type, 1), "anon_1": ( tc.anon_label, (tc_anon_label, "anon_1", tc), tc.type, + 2, ), }, ) @@ -5279,7 +5281,7 @@ class ResultMapTest(fixtures.TestBase): comp = stmt.compile(dialect=postgresql.dialect()) eq_( comp._create_result_map(), - {"a": ("a", (aint, "a", "a", "t2_a"), aint.type)}, + {"a": ("a", (aint, "a", "a", "t2_a"), aint.type, 0)}, ) def test_insert_from_select(self): @@ -5293,7 +5295,7 @@ class ResultMapTest(fixtures.TestBase): comp = stmt.compile(dialect=postgresql.dialect()) eq_( comp._create_result_map(), - {"a": ("a", (aint, "a", "a", "t2_a"), aint.type)}, + {"a": ("a", (aint, "a", "a", "t2_a"), aint.type, 0)}, ) def test_nested_api(self): @@ -5339,6 +5341,7 @@ class ResultMapTest(fixtures.TestBase): "myothertable_otherid", ), table2.c.otherid.type, + 0, ), "othername": ( "othername", @@ -5349,8 +5352,9 @@ class ResultMapTest(fixtures.TestBase): "myothertable_othername", ), table2.c.othername.type, + 1, ), - "k1": ("k1", (1, 2, 3), int_), + "k1": ("k1", (1, 2, 3), int_, 2), }, ) eq_( @@ -5360,12 +5364,14 @@ class ResultMapTest(fixtures.TestBase): "myid", (table1.c.myid, "myid", "myid", "mytable_myid"), table1.c.myid.type, + 0, ), - "k2": ("k2", (3, 4, 5), int_), + "k2": ("k2", (3, 4, 5), int_, 3), "name": ( "name", (table1.c.name, "name", "name", "mytable_name"), table1.c.name.type, + 1, ), "description": 
( "description", @@ -5376,6 +5382,7 @@ class ResultMapTest(fixtures.TestBase): "mytable_description", ), table1.c.description.type, + 2, ), }, ) diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py index c83c71ada3..04fed9b6e6 100644 --- a/test/sql/test_deprecations.py +++ b/test/sql/test_deprecations.py @@ -811,6 +811,7 @@ class TextualSelectTest(fixtures.TestBase, AssertsCompiledSQL): "myid", (table1.c.myid, "myid", "myid", "mytable_myid"), table1.c.myid.type, + 0, ) }, ) diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 44c1565e40..578e20e445 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -1475,6 +1475,16 @@ class KeyTargetingTest(fixtures.TablesTest): schema=testing.config.test_schema, ) + Table( + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("team_id", metadata, ForeignKey("teams.id")), + ) + Table( + "teams", metadata, Column("id", Integer, primary_key=True), + ) + @classmethod def insert_data(cls, connection): conn = connection @@ -1484,6 +1494,9 @@ class KeyTargetingTest(fixtures.TablesTest): conn.execute(cls.tables.keyed4.insert(), dict(b="b4", q="q4")) conn.execute(cls.tables.content.insert(), dict(type="t1")) + conn.execute(cls.tables.teams.insert(), dict(id=1)) + conn.execute(cls.tables.users.insert(), dict(id=1, team_id=1)) + if testing.requires.schemas.enabled: conn.execute( cls.tables["%s.wschema" % testing.config.test_schema].insert(), @@ -1815,7 +1828,7 @@ class KeyTargetingTest(fixtures.TablesTest): def _adapt_result_columns_fixture_two(self): return text("select a AS keyed2_a, b AS keyed2_b from keyed2").columns( - keyed2_a=CHAR, keyed2_b=CHAR + column("keyed2_a", CHAR), column("keyed2_b", CHAR) ) def _adapt_result_columns_fixture_three(self): @@ -1834,11 +1847,34 @@ class KeyTargetingTest(fixtures.TablesTest): return stmt2 + def _adapt_result_columns_fixture_five(self): + users, teams = self.tables("users", "teams") + return 
select([users.c.id, teams.c.id]).select_from( + users.outerjoin(teams) + ) + + def _adapt_result_columns_fixture_six(self): + # this has _result_columns structure that is not ordered + # the same as the cursor.description. + return text("select a AS keyed2_a, b AS keyed2_b from keyed2").columns( + keyed2_b=CHAR, keyed2_a=CHAR, + ) + + def _adapt_result_columns_fixture_seven(self): + # this has _result_columns structure that is not ordered + # the same as the cursor.description. + return text("select a AS keyed2_a, b AS keyed2_b from keyed2").columns( + keyed2_b=CHAR, bogus_col=CHAR + ) + @testing.combinations( _adapt_result_columns_fixture_one, _adapt_result_columns_fixture_two, _adapt_result_columns_fixture_three, _adapt_result_columns_fixture_four, + _adapt_result_columns_fixture_five, + _adapt_result_columns_fixture_six, + _adapt_result_columns_fixture_seven, argnames="stmt_fn", ) def test_adapt_result_columns(self, connection, stmt_fn): @@ -1863,31 +1899,41 @@ class KeyTargetingTest(fixtures.TablesTest): zip(stmt1.selected_columns, stmt2.selected_columns) ) - result = connection.execute(stmt1) + for i in range(2): + try: + result = connection.execute(stmt1) - mock_context = Mock( - compiled=result.context.compiled, invoked_statement=stmt2 - ) - existing_metadata = result._metadata - adapted_metadata = existing_metadata._adapt_to_context(mock_context) + mock_context = Mock( + compiled=result.context.compiled, invoked_statement=stmt2 + ) + existing_metadata = result._metadata + adapted_metadata = existing_metadata._adapt_to_context( + mock_context + ) - eq_(existing_metadata.keys, adapted_metadata.keys) + eq_(existing_metadata.keys, adapted_metadata.keys) - for k in existing_metadata._keymap: - if isinstance(k, ColumnElement) and k in column_linkage: - other_k = column_linkage[k] - else: - other_k = k + for k in existing_metadata._keymap: + if isinstance(k, ColumnElement) and k in column_linkage: + other_k = column_linkage[k] + else: + other_k = k - is_( - 
existing_metadata._keymap[k], adapted_metadata._keymap[other_k] - ) + is_( + existing_metadata._keymap[k], + adapted_metadata._keymap[other_k], + ) + finally: + result.close() @testing.combinations( _adapt_result_columns_fixture_one, _adapt_result_columns_fixture_two, _adapt_result_columns_fixture_three, _adapt_result_columns_fixture_four, + _adapt_result_columns_fixture_five, + _adapt_result_columns_fixture_six, + _adapt_result_columns_fixture_seven, argnames="stmt_fn", ) def test_adapt_result_columns_from_cache(self, connection, stmt_fn): @@ -1909,7 +1955,10 @@ class KeyTargetingTest(fixtures.TablesTest): row = result.first() for col in stmt2.selected_columns: - assert col in row._mapping + if "bogus" in col.name: + assert col not in row._mapping + else: + assert col in row._mapping class PositionalTextTest(fixtures.TablesTest): diff --git a/test/sql/test_text.py b/test/sql/test_text.py index 1a7ee6f344..9d5ab65ed9 100644 --- a/test/sql/test_text.py +++ b/test/sql/test_text.py @@ -435,6 +435,8 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): column("id", Integer), column("name") ) + col_pos = {col.name: idx for idx, col in enumerate(t.selected_columns)} + compiled = t.compile() eq_( compiled._create_result_map(), @@ -443,11 +445,13 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): "id", (t.selected_columns.id, "id", "id", "id"), t.selected_columns.id.type, + col_pos["id"], ), "name": ( "name", (t.selected_columns.name, "name", "name", "name"), t.selected_columns.name.type, + col_pos["name"], ), }, ) @@ -455,6 +459,8 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): def test_basic_toplevel_resultmap(self): t = text("select id, name from user").columns(id=Integer, name=String) + col_pos = {col.name: idx for idx, col in enumerate(t.selected_columns)} + compiled = t.compile() eq_( compiled._create_result_map(), @@ -463,11 +469,13 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): "id", (t.selected_columns.id, "id", "id", "id"), 
t.selected_columns.id.type, + col_pos["id"], ), "name": ( "name", (t.selected_columns.name, "name", "name", "name"), t.selected_columns.name.type, + col_pos["name"], ), }, ) @@ -490,6 +498,7 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL): "myid", (table1.c.myid, "myid", "myid", "mytable_myid"), table1.c.myid.type, + 0, ) }, )