]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
track item schema names to identify name collisions w/ default schema
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 17 Dec 2021 23:04:47 +0000 (18:04 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 14 Jan 2022 21:54:13 +0000 (16:54 -0500)
Added an additional lookup step to the compiler which will track all FROM
clauses which are tables, that may have the same name shared in multiple
schemas where one of the schemas is the implicit "default" schema; in this
case, the table name when referring to that name without a schema
qualification will be rendered with an anonymous alias name at the compiler
level in order to disambiguate the two (or more) names. The approach of
schema-qualifying the normally unqualified name with the server-detected
"default schema name" value was also considered, however this approach
doesn't apply to Oracle nor is it accepted by SQL Server, nor would it work
with multiple entries in the PostgreSQL search path. The name collision
issue resolved here has been identified as affecting at least Oracle,
PostgreSQL, SQL Server, MySQL and MariaDB.

Fixes: #7471
Change-Id: Id65e7ca8c43fe8d95777084e8d5ec140ebcd784d

doc/build/changelog/unreleased_20/7471.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/testing/suite/test_select.py
test/requirements.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_20/7471.rst b/doc/build/changelog/unreleased_20/7471.rst
new file mode 100644 (file)
index 0000000..344bc27
--- /dev/null
@@ -0,0 +1,17 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 7471
+
+    Added an additional lookup step to the compiler which will track all FROM
+    clauses which are tables, that may have the same name shared in multiple
+    schemas where one of the schemas is the implicit "default" schema; in this
+    case, the table name when referring to that name without a schema
+    qualification will be rendered with an anonymous alias name at the compiler
+    level in order to disambiguate the two (or more) names. The approach of
+    schema-qualifying the normally unqualified name with the server-detected
+    "default schema name" value was also considered, however this approach
+    doesn't apply to Oracle nor is it accepted by SQL Server, nor would it work
+    with multiple entries in the PostgreSQL search path. The name collision
+    issue resolved here has been identified as affecting at least Oracle,
+    PostgreSQL, SQL Server, MySQL and MariaDB.
+
index cba7cf07dcc2d9378373acb03ee8d9a104fbd573..4cff8defb9042a8323c01e3f63be68675e4b4e50 100644 (file)
@@ -683,7 +683,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         self._setup_for_generate()
 
         SelectState.__init__(self, self.statement, compiler, **kw)
-
         return self
 
     def _dump_option_struct(self):
index 6ab9a75f6f01857280123a141c7b0cd2c9f05ca4..ae586c9f2e115e27dffe2f8643bc9efe1bfa4fbd 100644 (file)
@@ -499,7 +499,7 @@ class CompileState:
 
     """
 
-    __slots__ = ("statement",)
+    __slots__ = ("statement", "_ambiguous_table_name_map")
 
     plugins = {}
 
index cb10811c6aa461dd024e64db9aa2ab451c614c3c..af39f06729b102c9a1f4380c87db2493cfb3eec7 100644 (file)
@@ -1466,6 +1466,7 @@ class SQLCompiler(Compiled):
         add_to_result_map=None,
         include_table=True,
         result_map_targets=(),
+        ambiguous_table_name_map=None,
         **kwargs,
     ):
         name = orig_name = column.name
@@ -1502,6 +1503,14 @@ class SQLCompiler(Compiled):
             else:
                 schema_prefix = ""
             tablename = table.name
+
+            if (
+                not effective_schema
+                and ambiguous_table_name_map
+                and tablename in ambiguous_table_name_map
+            ):
+                tablename = ambiguous_table_name_map[tablename]
+
             if isinstance(tablename, elements._truncated_label):
                 tablename = self._truncated_identifier("alias", tablename)
 
@@ -3252,6 +3261,10 @@ class SQLCompiler(Compiled):
         compile_state = select_stmt._compile_state_factory(
             select_stmt, self, **kwargs
         )
+        kwargs[
+            "ambiguous_table_name_map"
+        ] = compile_state._ambiguous_table_name_map
+
         select_stmt = compile_state.statement
 
         toplevel = not self.stack
@@ -3732,6 +3745,7 @@ class SQLCompiler(Compiled):
         fromhints=None,
         use_schema=True,
         from_linter=None,
+        ambiguous_table_name_map=None,
         **kwargs,
     ):
         if from_linter:
@@ -3748,6 +3762,20 @@ class SQLCompiler(Compiled):
                 )
             else:
                 ret = self.preparer.quote(table.name)
+
+                if (
+                    not effective_schema
+                    and ambiguous_table_name_map
+                    and table.name in ambiguous_table_name_map
+                ):
+                    anon_name = self._truncated_identifier(
+                        "alias", ambiguous_table_name_map[table.name]
+                    )
+
+                    ret = ret + self.get_render_as_alias_suffix(
+                        self.preparer.format_alias(None, anon_name)
+                    )
+
             if fromhints and table in fromhints:
                 ret = self.format_from_hint_text(
                     ret, table, fromhints[table], iscrud
index a025cce357fb7bb0e1efb7720663279122aefe5e..1fa312b7e001f4be612ad8c7bb015a32736d5ef8 100644 (file)
@@ -286,6 +286,7 @@ class ClauseElement(
     is_clause_element = True
     is_selectable = False
 
+    _is_table = False
     _is_textual = False
     _is_from_clause = False
     _is_returns_rows = False
index e674c4b74d657ec50577a5e57796549fb194d9c1..6a7b835042dbb82095e2bd5510ad8db1ab16ce27 100644 (file)
@@ -2484,6 +2484,8 @@ class TableClause(roles.DMLTableRole, Immutable, FromClause):
 
     named_with_column = True
 
+    _is_table = True
+
     implicit_returning = False
     """:class:`_expression.TableClause`
     doesn't support having a primary key or column
@@ -3980,6 +3982,8 @@ class SelectState(util.MemoizedSlots, CompileState):
         return go
 
     def _get_froms(self, statement):
+        self._ambiguous_table_name_map = ambiguous_table_name_map = {}
+
         return self._normalize_froms(
             itertools.chain(
                 itertools.chain.from_iterable(
@@ -3997,10 +4001,16 @@ class SelectState(util.MemoizedSlots, CompileState):
                 self.from_clauses,
             ),
             check_statement=statement,
+            ambiguous_table_name_map=ambiguous_table_name_map,
         )
 
     @classmethod
-    def _normalize_froms(cls, iterable_of_froms, check_statement=None):
+    def _normalize_froms(
+        cls,
+        iterable_of_froms,
+        check_statement=None,
+        ambiguous_table_name_map=None,
+    ):
         """given an iterable of things to select FROM, reduce them to what
         would actually render in the FROM clause of a SELECT.
 
@@ -4013,6 +4023,7 @@ class SelectState(util.MemoizedSlots, CompileState):
         froms = []
 
         for item in iterable_of_froms:
+
             if item._is_subquery and item.element is check_statement:
                 raise exc.InvalidRequestError(
                     "select() construct refers to itself as a FROM"
@@ -4033,6 +4044,21 @@ class SelectState(util.MemoizedSlots, CompileState):
                 # using a list to maintain ordering
                 froms = [f for f in froms if f not in toremove]
 
+            if ambiguous_table_name_map is not None:
+                ambiguous_table_name_map.update(
+                    (
+                        fr.name,
+                        _anonymous_label.safe_construct(
+                            hash(fr.name), fr.name
+                        ),
+                    )
+                    for item in froms
+                    for fr in item._from_objects
+                    if fr._is_table
+                    and fr.schema
+                    and fr.name not in ambiguous_table_name_map
+                )
+
         return froms
 
     def _get_display_froms(
index c1228f5df30a692aad3946a975e11f0f318b8a4d..92fd29503e66f982545f33f9cd0b74f6cd1e1346 100644 (file)
@@ -624,6 +624,105 @@ class FetchLimitOffsetTest(fixtures.TablesTest):
         eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)]))
 
 
+class SameNamedSchemaTableTest(fixtures.TablesTest):
+    """tests for #7471"""
+
+    __backend__ = True
+
+    __requires__ = ("schemas",)
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "some_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            schema=config.test_schema,
+        )
+        Table(
+            "some_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column(
+                "some_table_id",
+                Integer,
+                # ForeignKey("%s.some_table.id" % config.test_schema),
+                nullable=False,
+            ),
+        )
+
+    @classmethod
+    def insert_data(cls, connection):
+        some_table, some_table_schema = cls.tables(
+            "some_table", "%s.some_table" % config.test_schema
+        )
+        connection.execute(some_table_schema.insert(), {"id": 1})
+        connection.execute(some_table.insert(), {"id": 1, "some_table_id": 1})
+
+    def test_simple_join_both_tables(self, connection):
+        some_table, some_table_schema = self.tables(
+            "some_table", "%s.some_table" % config.test_schema
+        )
+
+        eq_(
+            connection.execute(
+                select(some_table, some_table_schema).join_from(
+                    some_table,
+                    some_table_schema,
+                    some_table.c.some_table_id == some_table_schema.c.id,
+                )
+            ).first(),
+            (1, 1, 1),
+        )
+
+    def test_simple_join_whereclause_only(self, connection):
+        some_table, some_table_schema = self.tables(
+            "some_table", "%s.some_table" % config.test_schema
+        )
+
+        eq_(
+            connection.execute(
+                select(some_table)
+                .join_from(
+                    some_table,
+                    some_table_schema,
+                    some_table.c.some_table_id == some_table_schema.c.id,
+                )
+                .where(some_table.c.id == 1)
+            ).first(),
+            (1, 1),
+        )
+
+    def test_subquery(self, connection):
+        some_table, some_table_schema = self.tables(
+            "some_table", "%s.some_table" % config.test_schema
+        )
+
+        subq = (
+            select(some_table)
+            .join_from(
+                some_table,
+                some_table_schema,
+                some_table.c.some_table_id == some_table_schema.c.id,
+            )
+            .where(some_table.c.id == 1)
+            .subquery()
+        )
+
+        eq_(
+            connection.execute(
+                select(some_table, subq.c.id)
+                .join_from(
+                    some_table,
+                    subq,
+                    some_table.c.some_table_id == subq.c.id,
+                )
+                .where(some_table.c.id == 1)
+            ).first(),
+            (1, 1, 1),
+        )
+
+
 class JoinTest(fixtures.TablesTest):
     __backend__ = True
 
index d5789d0e55137995586631e6f31b7266df329acc..b42bab7d35408e4bd1dec4dac60df4136913d6f1 100644 (file)
@@ -510,6 +510,9 @@ class DefaultRequirements(SuiteRequirements):
 
         basically, PostgreSQL.
 
+        TODO: what does this mean?  all the backends have a "default"
+        schema
+
         """
         return only_on(["postgresql"])
 
index 5ea1110c6f8b3ca97ad48fc8a3a7d55b26b2c9ae..c273dbbf8767e70f56b41c9a14b879d4b387ae29 100644 (file)
@@ -5624,6 +5624,78 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
             render_schema_translate=True,
         )
 
+    def test_schema_non_schema_disambiguation(self):
+        """test #7471"""
+
+        t1 = table("some_table", column("id"), column("q"))
+        t2 = table("some_table", column("id"), column("p"), schema="foo")
+
+        self.assert_compile(
+            select(t1, t2),
+            "SELECT some_table_1.id, some_table_1.q, "
+            "foo.some_table.id AS id_1, foo.some_table.p "
+            "FROM some_table AS some_table_1, foo.some_table",
+        )
+
+        self.assert_compile(
+            select(t1, t2).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL),
+            # the original "tablename_colname" label is preserved despite
+            # the alias of some_table
+            "SELECT some_table_1.id AS some_table_id, some_table_1.q AS "
+            "some_table_q, foo.some_table.id AS foo_some_table_id, "
+            "foo.some_table.p AS foo_some_table_p "
+            "FROM some_table AS some_table_1, foo.some_table",
+        )
+
+        self.assert_compile(
+            select(t1, t2).join_from(t1, t2, t1.c.id == t2.c.id),
+            "SELECT some_table_1.id, some_table_1.q, "
+            "foo.some_table.id AS id_1, foo.some_table.p "
+            "FROM some_table AS some_table_1 "
+            "JOIN foo.some_table ON some_table_1.id = foo.some_table.id",
+        )
+
+        self.assert_compile(
+            select(t1, t2).where(t1.c.id == t2.c.id),
+            "SELECT some_table_1.id, some_table_1.q, "
+            "foo.some_table.id AS id_1, foo.some_table.p "
+            "FROM some_table AS some_table_1, foo.some_table "
+            "WHERE some_table_1.id = foo.some_table.id",
+        )
+
+        self.assert_compile(
+            select(t1).where(t1.c.id == t2.c.id),
+            "SELECT some_table_1.id, some_table_1.q "
+            "FROM some_table AS some_table_1, foo.some_table "
+            "WHERE some_table_1.id = foo.some_table.id",
+        )
+
+        subq = select(t1).where(t1.c.id == t2.c.id).subquery()
+        self.assert_compile(
+            select(t2).select_from(t2).join(subq, t2.c.id == subq.c.id),
+            "SELECT foo.some_table.id, foo.some_table.p "
+            "FROM foo.some_table JOIN "
+            "(SELECT some_table_1.id AS id, some_table_1.q AS q "
+            "FROM some_table AS some_table_1, foo.some_table "
+            "WHERE some_table_1.id = foo.some_table.id) AS anon_1 "
+            "ON foo.some_table.id = anon_1.id",
+        )
+
+        self.assert_compile(
+            select(t1, subq.c.id)
+            .select_from(t1)
+            .join(subq, t1.c.id == subq.c.id),
+            # some_table is only aliased inside the subquery.  this is not
+            # any challenge for the compiler, just checking as this is a new
+            # source of aliasing.
+            "SELECT some_table.id, some_table.q, anon_1.id AS id_1 "
+            "FROM some_table "
+            "JOIN (SELECT some_table_1.id AS id, some_table_1.q AS q "
+            "FROM some_table AS some_table_1, foo.some_table "
+            "WHERE some_table_1.id = foo.some_table.id) AS anon_1 "
+            "ON some_table.id = anon_1.id",
+        )
+
     def test_alias(self):
         a = alias(table4, "remtable")
         self.assert_compile(