]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
denormalize "public" schema to "PUBLIC"
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 9 Mar 2023 16:49:46 +0000 (11:49 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 9 Mar 2023 21:49:38 +0000 (16:49 -0500)
Fixed reflection bug where Oracle "name normalize" would not work correctly
for reflection of symbols that are in the "PUBLIC" schema, such as
synonyms, meaning the PUBLIC name could not be indicated as lower case on
the Python side for the :paramref:`_schema.Table.schema` argument. Using
uppercase "PUBLIC" would work, but would then lead to awkward SQL queries
including a quoted ``"PUBLIC"`` name as well as indexing the table under
uppercase "PUBLIC", which was inconsistent.

Fixes: #9459
Change-Id: I989bd1e794a5b5ac9aae4f4a8702f14c56cd74c2

doc/build/changelog/unreleased_20/9459.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/base.py
test/dialect/oracle/test_reflection.py

diff --git a/doc/build/changelog/unreleased_20/9459.rst b/doc/build/changelog/unreleased_20/9459.rst
new file mode 100644 (file)
index 0000000..d4704cf
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, oracle
+    :tickets: 9459
+
+    Fixed reflection bug where Oracle "name normalize" would not work correctly
+    for reflection of symbols that are in the "PUBLIC" schema, such as
+    synonyms, meaning the PUBLIC name could not be indicated as lower case on
+    the Python side for the :paramref:`_schema.Table.schema` argument. Using
+    uppercase "PUBLIC" would work, but would then lead to awkward SQL queries
+    including a quoted ``"PUBLIC"`` name as well as indexing the table under
+    uppercase "PUBLIC", which was inconsistent.
index aa289111e35ce6c0a92bc751e589738eec6fa4e4..16990c751dc02cf3be9ffc7ccffc148bd17b1704 100644 (file)
@@ -1640,7 +1640,7 @@ class OracleDialect(default.DefaultDialect):
 
         params = {
             "table_name": self.denormalize_name(table_name),
-            "owner": self.denormalize_name(schema),
+            "owner": self.denormalize_schema_name(schema),
         }
         cursor = self._execute_reflection(
             connection,
@@ -1661,9 +1661,9 @@ class OracleDialect(default.DefaultDialect):
 
         query = select(dictionary.all_sequences.c.sequence_name).where(
             dictionary.all_sequences.c.sequence_name
-            == self.denormalize_name(sequence_name),
+            == self.denormalize_schema_name(sequence_name),
             dictionary.all_sequences.c.sequence_owner
-            == self.denormalize_name(schema),
+            == self.denormalize_schema_name(schema),
         )
 
         cursor = self._execute_reflection(
@@ -1678,13 +1678,23 @@ class OracleDialect(default.DefaultDialect):
             ).scalar()
         )
 
+    def denormalize_schema_name(self, name):
+        # look for quoted_name
+        force = getattr(name, "quote", None)
+        if force is None and name == "public":
+            # look for case insensitive, no quoting specified, "public"
+            return "PUBLIC"
+        return super().denormalize_name(name)
+
     @reflection.flexi_cache(
         ("schema", InternalTraversal.dp_string),
         ("filter_names", InternalTraversal.dp_string_list),
         ("dblink", InternalTraversal.dp_string),
     )
     def _get_synonyms(self, connection, schema, filter_names, dblink, **kw):
-        owner = self.denormalize_name(schema or self.default_schema_name)
+        owner = self.denormalize_schema_name(
+            schema or self.default_schema_name
+        )
 
         has_filter_names, params = self._prepare_filter_names(filter_names)
         query = select(
@@ -1775,7 +1785,9 @@ class OracleDialect(default.DefaultDialect):
     def _get_all_objects(
         self, connection, schema, scope, kind, filter_names, dblink, **kw
     ):
-        owner = self.denormalize_name(schema or self.default_schema_name)
+        owner = self.denormalize_schema_name(
+            schema or self.default_schema_name
+        )
 
         has_filter_names, params = self._prepare_filter_names(filter_names)
         has_mat_views = False
@@ -1864,7 +1876,7 @@ class OracleDialect(default.DefaultDialect):
         if schema is None:
             schema = self.default_schema_name
 
-        den_schema = self.denormalize_name(schema)
+        den_schema = self.denormalize_schema_name(schema)
         if kw.get("oracle_resolve_synonyms", False):
             tables = (
                 select(
@@ -1935,7 +1947,7 @@ class OracleDialect(default.DefaultDialect):
     @reflection.cache
     def get_temp_table_names(self, connection, dblink=None, **kw):
         """Supported kw arguments are: ``dblink`` to reflect via a db link."""
-        schema = self.denormalize_name(self.default_schema_name)
+        schema = self.denormalize_schema_name(self.default_schema_name)
 
         query = select(dictionary.all_tables.c.table_name)
         if self.exclude_tablespaces:
@@ -1964,7 +1976,8 @@ class OracleDialect(default.DefaultDialect):
             schema = self.default_schema_name
 
         query = select(dictionary.all_mviews.c.mview_name).where(
-            dictionary.all_mviews.c.owner == self.denormalize_name(schema)
+            dictionary.all_mviews.c.owner
+            == self.denormalize_schema_name(schema)
         )
         result = self._execute_reflection(
             connection, query, dblink, returns_long=False
@@ -1981,7 +1994,8 @@ class OracleDialect(default.DefaultDialect):
             schema = self.default_schema_name
 
         query = select(dictionary.all_views.c.view_name).where(
-            dictionary.all_views.c.owner == self.denormalize_name(schema)
+            dictionary.all_views.c.owner
+            == self.denormalize_schema_name(schema)
         )
         result = self._execute_reflection(
             connection, query, dblink, returns_long=False
@@ -1995,7 +2009,7 @@ class OracleDialect(default.DefaultDialect):
             schema = self.default_schema_name
         query = select(dictionary.all_sequences.c.sequence_name).where(
             dictionary.all_sequences.c.sequence_owner
-            == self.denormalize_name(schema)
+            == self.denormalize_schema_name(schema)
         )
 
         result = self._execute_reflection(
@@ -2095,7 +2109,9 @@ class OracleDialect(default.DefaultDialect):
         """Supported kw arguments are: ``dblink`` to reflect via a db link;
         ``oracle_resolve_synonyms`` to resolve names to synonyms
         """
-        owner = self.denormalize_name(schema or self.default_schema_name)
+        owner = self.denormalize_schema_name(
+            schema or self.default_schema_name
+        )
 
         has_filter_names, params = self._prepare_filter_names(filter_names)
         has_mat_views = False
@@ -2270,7 +2286,9 @@ class OracleDialect(default.DefaultDialect):
         """Supported kw arguments are: ``dblink`` to reflect via a db link;
         ``oracle_resolve_synonyms`` to resolve names to synonyms
         """
-        owner = self.denormalize_name(schema or self.default_schema_name)
+        owner = self.denormalize_schema_name(
+            schema or self.default_schema_name
+        )
         query = self._column_query(owner)
 
         if (
@@ -2505,7 +2523,9 @@ class OracleDialect(default.DefaultDialect):
         """Supported kw arguments are: ``dblink`` to reflect via a db link;
         ``oracle_resolve_synonyms`` to resolve names to synonyms
         """
-        owner = self.denormalize_name(schema or self.default_schema_name)
+        owner = self.denormalize_schema_name(
+            schema or self.default_schema_name
+        )
         has_filter_names, params = self._prepare_filter_names(filter_names)
         query = self._comment_query(owner, scope, kind, has_filter_names)
 
@@ -2586,7 +2606,9 @@ class OracleDialect(default.DefaultDialect):
         ("all_objects", InternalTraversal.dp_string_list),
     )
     def _get_indexes_rows(self, connection, schema, dblink, all_objects, **kw):
-        owner = self.denormalize_name(schema or self.default_schema_name)
+        owner = self.denormalize_schema_name(
+            schema or self.default_schema_name
+        )
 
         query = self._index_query(owner)
 
@@ -2755,7 +2777,9 @@ class OracleDialect(default.DefaultDialect):
     def _get_all_constraint_rows(
         self, connection, schema, dblink, all_objects, **kw
     ):
-        owner = self.denormalize_name(schema or self.default_schema_name)
+        owner = self.denormalize_schema_name(
+            schema or self.default_schema_name
+        )
         query = self._constraint_query(owner)
 
         # since the result is cached a list must be created
@@ -2859,7 +2883,9 @@ class OracleDialect(default.DefaultDialect):
 
         resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
 
-        owner = self.denormalize_name(schema or self.default_schema_name)
+        owner = self.denormalize_schema_name(
+            schema or self.default_schema_name
+        )
 
         all_remote_owners = set()
         fkeys = defaultdict(dict)
@@ -3080,7 +3106,9 @@ class OracleDialect(default.DefaultDialect):
                 view_name = row_dict["table_name"]
 
         name = self.denormalize_name(view_name)
-        owner = self.denormalize_name(schema or self.default_schema_name)
+        owner = self.denormalize_schema_name(
+            schema or self.default_schema_name
+        )
         query = (
             select(dictionary.all_views.c.text)
             .where(
index ae50264f7beae9c3a3542efa2388ee9789e8f653..ac58d369488ec047af76bb25d7a72d59692e5bcb 100644 (file)
@@ -1597,3 +1597,120 @@ drop table %(schema)sparent;
             connection.exec_driver_sql("DROP SYNONYM s1")
             connection.exec_driver_sql("DROP SYNONYM s2")
             connection.exec_driver_sql("DROP SYNONYM s3")
+
+    @testing.fixture
+    def public_synonym_fixture(self, connection):
+        foo_syn = f"foo_syn_{config.ident}"
+
+        connection.exec_driver_sql("CREATE TABLE foobar (id integer)")
+
+        try:
+            connection.exec_driver_sql(
+                f"CREATE PUBLIC SYNONYM {foo_syn} for foobar"
+            )
+        except:
+            # assume the synonym exists is the main problem here.
+            # since --dropfirst will not get this synonym, drop it directly
+            # for the next run.
+            try:
+                connection.exec_driver_sql(f"DROP PUBLIC SYNONYM {foo_syn}")
+            except:
+                pass
+
+            raise
+
+        try:
+            yield foo_syn
+        finally:
+            try:
+                connection.exec_driver_sql(f"DROP PUBLIC SYNONYM {foo_syn}")
+            except:
+                pass
+            try:
+                connection.exec_driver_sql("DROP TABLE foobar")
+            except:
+                pass
+
+    @testing.variation(
+        "case_convention", ["uppercase", "lowercase", "mixedcase"]
+    )
+    def test_public_synonym_fetch(
+        self,
+        connection,
+        public_synonym_fixture,
+        case_convention: testing.Variation,
+    ):
+        """test #9459"""
+
+        foo_syn = public_synonym_fixture
+
+        if case_convention.uppercase:
+            public = "PUBLIC"
+        elif case_convention.lowercase:
+            public = "public"
+        elif case_convention.mixedcase:
+            public = "Public"
+        else:
+            case_convention.fail()
+
+        syns = connection.dialect._get_synonyms(connection, public, None, None)
+
+        if case_convention.mixedcase:
+            assert not syns
+            return
+
+        syns_by_name = {syn["synonym_name"]: syn for syn in syns}
+        eq_(
+            syns_by_name[foo_syn.upper()],
+            {
+                "synonym_name": foo_syn.upper(),
+                "table_name": "FOOBAR",
+                "table_owner": connection.dialect.default_schema_name.upper(),
+                "db_link": None,
+            },
+        )
+
+    @testing.variation(
+        "case_convention", ["uppercase", "lowercase", "mixedcase"]
+    )
+    def test_public_synonym_resolve_table(
+        self,
+        connection,
+        public_synonym_fixture,
+        case_convention: testing.Variation,
+    ):
+        """test #9459"""
+
+        foo_syn = public_synonym_fixture
+
+        if case_convention.uppercase:
+            public = "PUBLIC"
+        elif case_convention.lowercase:
+            public = "public"
+        elif case_convention.mixedcase:
+            public = "Public"
+        else:
+            case_convention.fail()
+
+        if case_convention.mixedcase:
+            with expect_raises(exc.NoSuchTableError):
+                cols = inspect(connection).get_columns(
+                    foo_syn, schema=public, oracle_resolve_synonyms=True
+                )
+        else:
+            cols = inspect(connection).get_columns(
+                foo_syn, schema=public, oracle_resolve_synonyms=True
+            )
+
+            eq_(
+                cols,
+                [
+                    {
+                        "name": "id",
+                        "type": testing.eq_type_affinity(INTEGER),
+                        "nullable": True,
+                        "default": None,
+                        "comment": None,
+                    }
+                ],
+            )