Don't rely on string col name in adapt_to_context
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 5 Sep 2020 23:45:04 +0000 (19:45 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Sun, 6 Sep 2020 13:55:27 +0000 (09:55 -0400)
Fixed an issue where, even though the method claims to match up
columns positionally, it was in fact looking up entries in the
"keymap" by string name.

Adds a new member, MD_RESULT_MAP_INDEX, to the _keymap records
so that we can efficiently link from the generated keymap back
to the compiled._result_columns structure without any ambiguity.

Fixes: #5559
Change-Id: Ie2fa9165c16625ef860ffac1190e00575e96761f

lib/sqlalchemy/engine/cursor.py
test/dialect/oracle/test_compiler.py
test/sql/test_compiler.py
test/sql/test_deprecations.py
test/sql/test_resultset.py
test/sql/test_text.py
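
The idea behind the change can be seen in the following standalone sketch
(plain Python, not SQLAlchemy's actual classes; the helper name and the
sample records are hypothetical): each keymap record now carries the index
of its entry in compiled._result_columns at position MD_RESULT_MAP_INDEX,
and a dictionary keyed by that index lets _adapt_to_context match an
invoked statement's exported columns to existing records purely by
position, never by rendered string name.

    # minimal sketch, assuming the MD_* tuple layout introduced by this
    # commit; the helper and sample data are illustrative only
    MD_INDEX = 0             # integer index in cursor.description
    MD_RESULT_MAP_INDEX = 1  # integer index in compiled._result_columns

    def adapt_keymap_positionally(keymap_by_result_column_idx, new_columns):
        """Match new exported columns to existing keymap records purely
        by their position in compiled._result_columns."""
        adapted = {}
        for idx, new_col in enumerate(new_columns):
            rec = keymap_by_result_column_idx.get(idx)
            if rec is None:
                # e.g. a bogus column in a TextualSelect that has no
                # corresponding entry in the result map
                continue
            adapted[new_col] = rec
        return adapted

    # records keyed by their position in compiled._result_columns
    recs = {
        0: (0, 0, (), "keyed2_a"),
        1: (1, 1, (), "keyed2_b"),
    }
    print(adapt_keymap_positionally(recs, ["a_prime", "b_prime"]))
    # {'a_prime': (0, 0, (), 'keyed2_a'), 'b_prime': (1, 1, (), 'keyed2_b')}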

index 1b48509b4ccdad2f4aee90c54957c08cdeb29f7a..43afa3628afb4bb27ce6c240d05347f7fa059e2c 100644 (file)
@@ -34,11 +34,12 @@ _UNPICKLED = util.symbol("unpickled")
 # metadata entry tuple indexes.
 # using raw tuple is faster than namedtuple.
 MD_INDEX = 0  # integer index in cursor.description
-MD_OBJECTS = 1  # other string keys and ColumnElement obj that can match
-MD_LOOKUP_KEY = 2  # string key we usually expect for key-based lookup
-MD_RENDERED_NAME = 3  # name that is usually in cursor.description
-MD_PROCESSOR = 4  # callable to process a result value into a row
-MD_UNTRANSLATED = 5  # raw name from cursor.description
+MD_RESULT_MAP_INDEX = 1  # integer index in compiled._result_columns
+MD_OBJECTS = 2  # other string keys and ColumnElement obj that can match
+MD_LOOKUP_KEY = 3  # string key we usually expect for key-based lookup
+MD_RENDERED_NAME = 4  # name that is usually in cursor.description
+MD_PROCESSOR = 5  # callable to process a result value into a row
+MD_UNTRANSLATED = 6  # raw name from cursor.description
 
 
 class CursorResultMetaData(ResultMetaData):
@@ -49,6 +50,7 @@ class CursorResultMetaData(ResultMetaData):
         "case_sensitive",
         "_processors",
         "_keys",
+        "_keymap_by_result_column_idx",
         "_tuplefilter",
         "_translated_indexes",
         "_safe_for_cache"
@@ -112,6 +114,7 @@ class CursorResultMetaData(ResultMetaData):
         as matched to those of the cached statement.
 
         """
+
         if not context.compiled._result_columns:
             return self
 
@@ -127,13 +130,19 @@ class CursorResultMetaData(ResultMetaData):
 
         md._keymap = dict(self._keymap)
 
-        # match up new columns positionally to the result columns
-        for existing, new in zip(
-            context.compiled._result_columns,
-            invoked_statement._exported_columns_iterator(),
+        keymap_by_position = self._keymap_by_result_column_idx
+
+        for idx, new in enumerate(
+            invoked_statement._exported_columns_iterator()
         ):
-            if existing[RM_NAME] in md._keymap:
-                md._keymap[new] = md._keymap[existing[RM_NAME]]
+            try:
+                rec = keymap_by_position[idx]
+            except KeyError:
+                # this can happen when there are bogus column entries
+                # in a TextualSelect
+                pass
+            else:
+                md._keymap[new] = rec
 
         md.case_sensitive = self.case_sensitive
         md._processors = self._processors
@@ -141,6 +150,8 @@ class CursorResultMetaData(ResultMetaData):
         md._tuplefilter = None
         md._translated_indexes = None
         md._keys = self._keys
+        md._keymap_by_result_column_idx = self._keymap_by_result_column_idx
+        md._safe_for_cache = self._safe_for_cache
         return md
 
     def __init__(self, parent, cursor_description):
@@ -186,6 +197,12 @@ class CursorResultMetaData(ResultMetaData):
             metadata_entry[MD_PROCESSOR] for metadata_entry in raw
         ]
 
+        if context.compiled:
+            self._keymap_by_result_column_idx = {
+                metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry
+                for metadata_entry in raw
+            }
+
         # keymap by primary string...
         by_key = dict(
             [
@@ -237,7 +254,7 @@ class CursorResultMetaData(ResultMetaData):
 
                 # then for the dupe keys, put the "ambiguous column"
                 # record into by_key.
-                by_key.update({key: (None, (), key) for key in dupes})
+                by_key.update({key: (None, None, (), key) for key in dupes})
 
             else:
                 # no dupes - copy secondary elements from compiled
@@ -350,6 +367,7 @@ class CursorResultMetaData(ResultMetaData):
             self._safe_for_cache = True
             return [
                 (
+                    idx,
                     idx,
                     rmap_entry[RM_OBJECTS],
                     rmap_entry[RM_NAME].lower()
@@ -399,6 +417,7 @@ class CursorResultMetaData(ResultMetaData):
             return [
                 (
                     idx,
+                    ridx,
                     obj,
                     cursor_colname,
                     cursor_colname,
@@ -409,6 +428,7 @@ class CursorResultMetaData(ResultMetaData):
                 )
                 for (
                     idx,
+                    ridx,
                     cursor_colname,
                     mapped_type,
                     coltype,
@@ -480,6 +500,7 @@ class CursorResultMetaData(ResultMetaData):
             if idx < num_ctx_cols:
                 ctx_rec = result_columns[idx]
                 obj = ctx_rec[RM_OBJECTS]
+                ridx = idx
                 mapped_type = ctx_rec[RM_TYPE]
                 if obj[0] in seen:
                     raise exc.InvalidRequestError(
@@ -490,7 +511,8 @@ class CursorResultMetaData(ResultMetaData):
             else:
                 mapped_type = sqltypes.NULLTYPE
                 obj = None
-            yield idx, colname, mapped_type, coltype, obj, untranslated
+                ridx = None
+            yield idx, ridx, colname, mapped_type, coltype, obj, untranslated
 
     def _merge_cols_by_name(
         self,
@@ -504,7 +526,6 @@ class CursorResultMetaData(ResultMetaData):
         match_map = self._create_description_match_map(
             result_columns, case_sensitive, loose_column_name_matching
         )
-
         for (
             idx,
             colname,
@@ -516,10 +537,20 @@ class CursorResultMetaData(ResultMetaData):
             except KeyError:
                 mapped_type = sqltypes.NULLTYPE
                 obj = None
+                result_columns_idx = None
             else:
                 obj = ctx_rec[1]
                 mapped_type = ctx_rec[2]
-            yield idx, colname, mapped_type, coltype, obj, untranslated
+                result_columns_idx = ctx_rec[3]
+            yield (
+                idx,
+                result_columns_idx,
+                colname,
+                mapped_type,
+                coltype,
+                obj,
+                untranslated,
+            )
 
     @classmethod
     def _create_description_match_map(
@@ -534,7 +565,7 @@ class CursorResultMetaData(ResultMetaData):
         """
 
         d = {}
-        for elem in result_columns:
+        for ridx, elem in enumerate(result_columns):
             key = elem[RM_RENDERED_NAME]
 
             if not case_sensitive:
@@ -544,10 +575,10 @@ class CursorResultMetaData(ResultMetaData):
                 # to the existing record.  if there is a duplicate column
                 # name in the cursor description, this will allow all of those
                 # objects to raise an ambiguous column error
-                e_name, e_obj, e_type = d[key]
-                d[key] = e_name, e_obj + elem[RM_OBJECTS], e_type
+                e_name, e_obj, e_type, e_ridx = d[key]
+                d[key] = e_name, e_obj + elem[RM_OBJECTS], e_type, ridx
             else:
-                d[key] = (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE])
+                d[key] = (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx)
 
             if loose_column_name_matching:
                 # when using a textual statement with an unordered set
@@ -557,7 +588,8 @@ class CursorResultMetaData(ResultMetaData):
                 # duplicate keys that are ambiguous will be fixed later.
                 for r_key in elem[RM_OBJECTS]:
                     d.setdefault(
-                        r_key, (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE])
+                        r_key,
+                        (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx),
                     )
 
         return d
@@ -569,7 +601,15 @@ class CursorResultMetaData(ResultMetaData):
             untranslated,
             coltype,
         ) in self._colnames_from_description(context, cursor_description):
-            yield idx, colname, sqltypes.NULLTYPE, coltype, None, untranslated
+            yield (
+                idx,
+                None,
+                colname,
+                sqltypes.NULLTYPE,
+                coltype,
+                None,
+                untranslated,
+            )
 
     def _key_fallback(self, key, err, raiseerr=True):
         if raiseerr:
@@ -637,7 +677,7 @@ class CursorResultMetaData(ResultMetaData):
     def __getstate__(self):
         return {
             "_keymap": {
-                key: (rec[MD_INDEX], _UNPICKLED, key)
+                key: (rec[MD_INDEX], rec[MD_RESULT_MAP_INDEX], _UNPICKLED, key)
                 for key, rec in self._keymap.items()
                 if isinstance(key, util.string_types + util.int_types)
             },
@@ -651,6 +691,9 @@ class CursorResultMetaData(ResultMetaData):
         self._processors = [None for _ in range(len(state["_keys"]))]
         self._keymap = state["_keymap"]
 
+        self._keymap_by_result_column_idx = {
+            rec[MD_RESULT_MAP_INDEX]: rec for rec in self._keymap.values()
+        }
         self._keys = state["_keys"]
         self.case_sensitive = state["case_sensitive"]
 
@@ -1220,6 +1263,7 @@ class BaseCursorResult(object):
             self._metadata = _NO_RESULT_METADATA
 
     def _init_metadata(self, context, cursor_description):
+
         if context.compiled:
             if context.compiled._cached_metadata:
                 metadata = self.context.compiled._cached_metadata
index a4a8cd99f7cdbeaccc84188d877278e43114abe8..8bfaded8fef5d0a6e1a47080b2e35e9d3abee50c 100644 (file)
@@ -545,7 +545,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         stmt = select(type_coerce(column("x"), MyType).label("foo")).limit(1)
         dialect = oracle.dialect()
         compiled = stmt.compile(dialect=dialect)
-        assert isinstance(compiled._create_result_map()["foo"][-1], MyType)
+        assert isinstance(compiled._create_result_map()["foo"][-2], MyType)
 
     def test_use_binds_for_limits_disabled_one(self):
         t = table("sometable", column("col1"), column("col2"))
@@ -1061,8 +1061,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         eq_(
             compiled._create_result_map(),
             {
-                "c3": ("c3", (t1.c.c3, "c3", "c3"), t1.c.c3.type),
-                "lower": ("lower", (fn, "lower", None), fn.type),
+                "c3": ("c3", (t1.c.c3, "c3", "c3"), t1.c.c3.type, 1),
+                "lower": ("lower", (fn, "lower", None), fn.type, 0),
             },
         )
 
index 1084d30cb0fc02a922348af1fab9fcdb3e5b727d..b43d09045d58bca8ea270ad2618ad806a052823d 100644 (file)
@@ -4716,6 +4716,7 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
                         "here_yetagain_anotherid",
                     ),
                     t1.c.anotherid.type,
+                    0,
                 )
             },
         )
@@ -5200,8 +5201,8 @@ class ResultMapTest(fixtures.TestBase):
         eq_(
             comp._create_result_map(),
             {
-                "a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type),
-                "b": ("b", (t.c.b, "b", "b", "t_b"), t.c.b.type),
+                "a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type, 0),
+                "b": ("b", (t.c.b, "b", "b", "t_b"), t.c.b.type, 1),
             },
         )
 
@@ -5212,7 +5213,7 @@ class ResultMapTest(fixtures.TestBase):
         comp = stmt.compile()
         eq_(
             comp._create_result_map(),
-            {"a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type)},
+            {"a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type, 0)},
         )
 
     def test_compound_only_top_populates(self):
@@ -5221,7 +5222,7 @@ class ResultMapTest(fixtures.TestBase):
         comp = stmt.compile()
         eq_(
             comp._create_result_map(),
-            {"a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type)},
+            {"a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type, 0)},
         )
 
     def test_label_plus_element(self):
@@ -5234,12 +5235,13 @@ class ResultMapTest(fixtures.TestBase):
         eq_(
             comp._create_result_map(),
             {
-                "a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type),
-                "bar": ("bar", (l1, "bar"), l1.type),
+                "a": ("a", (t.c.a, "a", "a", "t_a"), t.c.a.type, 0),
+                "bar": ("bar", (l1, "bar"), l1.type, 1),
                 "anon_1": (
                     tc.anon_label,
                     (tc_anon_label, "anon_1", tc),
                     tc.type,
+                    2,
                 ),
             },
         )
@@ -5279,7 +5281,7 @@ class ResultMapTest(fixtures.TestBase):
         comp = stmt.compile(dialect=postgresql.dialect())
         eq_(
             comp._create_result_map(),
-            {"a": ("a", (aint, "a", "a", "t2_a"), aint.type)},
+            {"a": ("a", (aint, "a", "a", "t2_a"), aint.type, 0)},
         )
 
     def test_insert_from_select(self):
@@ -5293,7 +5295,7 @@ class ResultMapTest(fixtures.TestBase):
         comp = stmt.compile(dialect=postgresql.dialect())
         eq_(
             comp._create_result_map(),
-            {"a": ("a", (aint, "a", "a", "t2_a"), aint.type)},
+            {"a": ("a", (aint, "a", "a", "t2_a"), aint.type, 0)},
         )
 
     def test_nested_api(self):
@@ -5339,6 +5341,7 @@ class ResultMapTest(fixtures.TestBase):
                         "myothertable_otherid",
                     ),
                     table2.c.otherid.type,
+                    0,
                 ),
                 "othername": (
                     "othername",
@@ -5349,8 +5352,9 @@ class ResultMapTest(fixtures.TestBase):
                         "myothertable_othername",
                     ),
                     table2.c.othername.type,
+                    1,
                 ),
-                "k1": ("k1", (1, 2, 3), int_),
+                "k1": ("k1", (1, 2, 3), int_, 2),
             },
         )
         eq_(
@@ -5360,12 +5364,14 @@ class ResultMapTest(fixtures.TestBase):
                     "myid",
                     (table1.c.myid, "myid", "myid", "mytable_myid"),
                     table1.c.myid.type,
+                    0,
                 ),
-                "k2": ("k2", (3, 4, 5), int_),
+                "k2": ("k2", (3, 4, 5), int_, 3),
                 "name": (
                     "name",
                     (table1.c.name, "name", "name", "mytable_name"),
                     table1.c.name.type,
+                    1,
                 ),
                 "description": (
                     "description",
@@ -5376,6 +5382,7 @@ class ResultMapTest(fixtures.TestBase):
                         "mytable_description",
                     ),
                     table1.c.description.type,
+                    2,
                 ),
             },
         )
index c83c71ada36894f38df5b668f436647f579a67f1..04fed9b6e6c390933a5cf0b410c725d173db47dd 100644 (file)
@@ -811,6 +811,7 @@ class TextualSelectTest(fixtures.TestBase, AssertsCompiledSQL):
                     "myid",
                     (table1.c.myid, "myid", "myid", "mytable_myid"),
                     table1.c.myid.type,
+                    0,
                 )
             },
         )
index 44c1565e403f43d0559a97dedf8a179792c9267c..578e20e445133155ffe3927d3ce1700266cf293f 100644 (file)
@@ -1475,6 +1475,16 @@ class KeyTargetingTest(fixtures.TablesTest):
                 schema=testing.config.test_schema,
             )
 
+        Table(
+            "users",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("team_id", metadata, ForeignKey("teams.id")),
+        )
+        Table(
+            "teams", metadata, Column("id", Integer, primary_key=True),
+        )
+
     @classmethod
     def insert_data(cls, connection):
         conn = connection
@@ -1484,6 +1494,9 @@ class KeyTargetingTest(fixtures.TablesTest):
         conn.execute(cls.tables.keyed4.insert(), dict(b="b4", q="q4"))
         conn.execute(cls.tables.content.insert(), dict(type="t1"))
 
+        conn.execute(cls.tables.teams.insert(), dict(id=1))
+        conn.execute(cls.tables.users.insert(), dict(id=1, team_id=1))
+
         if testing.requires.schemas.enabled:
             conn.execute(
                 cls.tables["%s.wschema" % testing.config.test_schema].insert(),
@@ -1815,7 +1828,7 @@ class KeyTargetingTest(fixtures.TablesTest):
 
     def _adapt_result_columns_fixture_two(self):
         return text("select a AS keyed2_a, b AS keyed2_b from keyed2").columns(
-            keyed2_a=CHAR, keyed2_b=CHAR
+            column("keyed2_a", CHAR), column("keyed2_b", CHAR)
         )
 
     def _adapt_result_columns_fixture_three(self):
@@ -1834,11 +1847,34 @@ class KeyTargetingTest(fixtures.TablesTest):
 
         return stmt2
 
+    def _adapt_result_columns_fixture_five(self):
+        users, teams = self.tables("users", "teams")
+        return select([users.c.id, teams.c.id]).select_from(
+            users.outerjoin(teams)
+        )
+
+    def _adapt_result_columns_fixture_six(self):
+        # this has _result_columns structure that is not ordered
+        # the same as the cursor.description.
+        return text("select a AS keyed2_a, b AS keyed2_b from keyed2").columns(
+            keyed2_b=CHAR, keyed2_a=CHAR,
+        )
+
+    def _adapt_result_columns_fixture_seven(self):
+        # this has _result_columns structure that is not ordered
+        # the same as the cursor.description.
+        return text("select a AS keyed2_a, b AS keyed2_b from keyed2").columns(
+            keyed2_b=CHAR, bogus_col=CHAR
+        )
+
     @testing.combinations(
         _adapt_result_columns_fixture_one,
         _adapt_result_columns_fixture_two,
         _adapt_result_columns_fixture_three,
         _adapt_result_columns_fixture_four,
+        _adapt_result_columns_fixture_five,
+        _adapt_result_columns_fixture_six,
+        _adapt_result_columns_fixture_seven,
         argnames="stmt_fn",
     )
     def test_adapt_result_columns(self, connection, stmt_fn):
@@ -1863,31 +1899,41 @@ class KeyTargetingTest(fixtures.TablesTest):
             zip(stmt1.selected_columns, stmt2.selected_columns)
         )
 
-        result = connection.execute(stmt1)
+        for i in range(2):
+            try:
+                result = connection.execute(stmt1)
 
-        mock_context = Mock(
-            compiled=result.context.compiled, invoked_statement=stmt2
-        )
-        existing_metadata = result._metadata
-        adapted_metadata = existing_metadata._adapt_to_context(mock_context)
+                mock_context = Mock(
+                    compiled=result.context.compiled, invoked_statement=stmt2
+                )
+                existing_metadata = result._metadata
+                adapted_metadata = existing_metadata._adapt_to_context(
+                    mock_context
+                )
 
-        eq_(existing_metadata.keys, adapted_metadata.keys)
+                eq_(existing_metadata.keys, adapted_metadata.keys)
 
-        for k in existing_metadata._keymap:
-            if isinstance(k, ColumnElement) and k in column_linkage:
-                other_k = column_linkage[k]
-            else:
-                other_k = k
+                for k in existing_metadata._keymap:
+                    if isinstance(k, ColumnElement) and k in column_linkage:
+                        other_k = column_linkage[k]
+                    else:
+                        other_k = k
 
-            is_(
-                existing_metadata._keymap[k], adapted_metadata._keymap[other_k]
-            )
+                    is_(
+                        existing_metadata._keymap[k],
+                        adapted_metadata._keymap[other_k],
+                    )
+            finally:
+                result.close()
 
     @testing.combinations(
         _adapt_result_columns_fixture_one,
         _adapt_result_columns_fixture_two,
         _adapt_result_columns_fixture_three,
         _adapt_result_columns_fixture_four,
+        _adapt_result_columns_fixture_five,
+        _adapt_result_columns_fixture_six,
+        _adapt_result_columns_fixture_seven,
         argnames="stmt_fn",
     )
     def test_adapt_result_columns_from_cache(self, connection, stmt_fn):
@@ -1909,7 +1955,10 @@ class KeyTargetingTest(fixtures.TablesTest):
 
         row = result.first()
         for col in stmt2.selected_columns:
-            assert col in row._mapping
+            if "bogus" in col.name:
+                assert col not in row._mapping
+            else:
+                assert col in row._mapping
 
 
 class PositionalTextTest(fixtures.TablesTest):
index 1a7ee6f344de53bebdb79c73d8e5e9c2bb2e51d8..9d5ab65ed97fd90628a7f0521c0301301defd689 100644 (file)
@@ -435,6 +435,8 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL):
             column("id", Integer), column("name")
         )
 
+        col_pos = {col.name: idx for idx, col in enumerate(t.selected_columns)}
+
         compiled = t.compile()
         eq_(
             compiled._create_result_map(),
@@ -443,11 +445,13 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL):
                     "id",
                     (t.selected_columns.id, "id", "id", "id"),
                     t.selected_columns.id.type,
+                    col_pos["id"],
                 ),
                 "name": (
                     "name",
                     (t.selected_columns.name, "name", "name", "name"),
                     t.selected_columns.name.type,
+                    col_pos["name"],
                 ),
             },
         )
@@ -455,6 +459,8 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_basic_toplevel_resultmap(self):
         t = text("select id, name from user").columns(id=Integer, name=String)
 
+        col_pos = {col.name: idx for idx, col in enumerate(t.selected_columns)}
+
         compiled = t.compile()
         eq_(
             compiled._create_result_map(),
@@ -463,11 +469,13 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL):
                     "id",
                     (t.selected_columns.id, "id", "id", "id"),
                     t.selected_columns.id.type,
+                    col_pos["id"],
                 ),
                 "name": (
                     "name",
                     (t.selected_columns.name, "name", "name", "name"),
                     t.selected_columns.name.type,
+                    col_pos["name"],
                 ),
             },
         )
@@ -490,6 +498,7 @@ class AsFromTest(fixtures.TestBase, AssertsCompiledSQL):
                     "myid",
                     (table1.c.myid, "myid", "myid", "mytable_myid"),
                     table1.c.myid.type,
+                    0,
                 )
             },
         )