From: Mike Bayer Date: Thu, 9 Mar 2023 16:49:46 +0000 (-0500) Subject: denormalize "public" schema to "PUBLIC" X-Git-Tag: rel_2_0_6~8 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=33ae862c054c4ab167aeab8cdc499b863c0f70a9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git denormalize "public" schema to "PUBLIC" Fixed reflection bug where Oracle "name normalize" would not work correctly for reflection of symbols that are in the "PUBLIC" schema, such as synonyms, meaning the PUBLIC name could not be indicated as lower case on the Python side for the :paramref:`_schema.Table.schema` argument. Using uppercase "PUBLIC" would work, but would then lead to awkward SQL queries including a quoted ``"PUBLIC"`` name as well as indexing the table under uppercase "PUBLIC", which was inconsistent. Fixes: #9459 Change-Id: I989bd1e794a5b5ac9aae4f4a8702f14c56cd74c2 --- diff --git a/doc/build/changelog/unreleased_20/9459.rst b/doc/build/changelog/unreleased_20/9459.rst new file mode 100644 index 0000000000..d4704cf013 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9459.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, oracle + :tickets: 9459 + + Fixed reflection bug where Oracle "name normalize" would not work correctly + for reflection of symbols that are in the "PUBLIC" schema, such as + synonyms, meaning the PUBLIC name could not be indicated as lower case on + the Python side for the :paramref:`_schema.Table.schema` argument. Using + uppercase "PUBLIC" would work, but would then lead to awkward SQL queries + including a quoted ``"PUBLIC"`` name as well as indexing the table under + uppercase "PUBLIC", which was inconsistent. diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index aa289111e3..16990c751d 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1640,7 +1640,7 @@ class OracleDialect(default.DefaultDialect): params = { "table_name": self.denormalize_name(table_name), - "owner": self.denormalize_name(schema), + "owner": self.denormalize_schema_name(schema), } cursor = self._execute_reflection( connection, @@ -1661,9 +1661,9 @@ class OracleDialect(default.DefaultDialect): query = select(dictionary.all_sequences.c.sequence_name).where( dictionary.all_sequences.c.sequence_name - == self.denormalize_name(sequence_name), + == self.denormalize_schema_name(sequence_name), dictionary.all_sequences.c.sequence_owner - == self.denormalize_name(schema), + == self.denormalize_schema_name(schema), ) cursor = self._execute_reflection( @@ -1678,13 +1678,23 @@ class OracleDialect(default.DefaultDialect): ).scalar() ) + def denormalize_schema_name(self, name): + # look for quoted_name + force = getattr(name, "quote", None) + if force is None and name == "public": + # look for case insensitive, no quoting specified, "public" + return "PUBLIC" + return super().denormalize_name(name) + @reflection.flexi_cache( ("schema", InternalTraversal.dp_string), ("filter_names", InternalTraversal.dp_string_list), ("dblink", InternalTraversal.dp_string), ) def _get_synonyms(self, connection, schema, filter_names, dblink, **kw): - owner = self.denormalize_name(schema or self.default_schema_name) + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) has_filter_names, params = self._prepare_filter_names(filter_names) query = select( @@ -1775,7 +1785,9 @@ class OracleDialect(default.DefaultDialect): def _get_all_objects( self, connection, schema, scope, kind, filter_names, dblink, **kw ): - owner = self.denormalize_name(schema or self.default_schema_name) + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) has_filter_names, params = self._prepare_filter_names(filter_names) has_mat_views = False @@ -1864,7 +1876,7 @@ class OracleDialect(default.DefaultDialect): if schema is None: schema = self.default_schema_name - den_schema = self.denormalize_name(schema) + den_schema = self.denormalize_schema_name(schema) if kw.get("oracle_resolve_synonyms", False): tables = ( select( @@ -1935,7 +1947,7 @@ class OracleDialect(default.DefaultDialect): @reflection.cache def get_temp_table_names(self, connection, dblink=None, **kw): """Supported kw arguments are: ``dblink`` to reflect via a db link.""" - schema = self.denormalize_name(self.default_schema_name) + schema = self.denormalize_schema_name(self.default_schema_name) query = select(dictionary.all_tables.c.table_name) if self.exclude_tablespaces: @@ -1964,7 +1976,8 @@ class OracleDialect(default.DefaultDialect): schema = self.default_schema_name query = select(dictionary.all_mviews.c.mview_name).where( - dictionary.all_mviews.c.owner == self.denormalize_name(schema) + dictionary.all_mviews.c.owner + == self.denormalize_schema_name(schema) ) result = self._execute_reflection( connection, query, dblink, returns_long=False @@ -1981,7 +1994,8 @@ class OracleDialect(default.DefaultDialect): schema = self.default_schema_name query = select(dictionary.all_views.c.view_name).where( - dictionary.all_views.c.owner == self.denormalize_name(schema) + dictionary.all_views.c.owner + == self.denormalize_schema_name(schema) ) result = self._execute_reflection( connection, query, dblink, returns_long=False @@ -1995,7 +2009,7 @@ class OracleDialect(default.DefaultDialect): schema = self.default_schema_name query = select(dictionary.all_sequences.c.sequence_name).where( dictionary.all_sequences.c.sequence_owner - == self.denormalize_name(schema) + == self.denormalize_schema_name(schema) ) result = self._execute_reflection( @@ -2095,7 +2109,9 @@ class OracleDialect(default.DefaultDialect): """Supported kw arguments are: ``dblink`` to reflect via a db link; ``oracle_resolve_synonyms`` to resolve names to synonyms """ - owner = self.denormalize_name(schema or self.default_schema_name) + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) has_filter_names, params = self._prepare_filter_names(filter_names) has_mat_views = False @@ -2270,7 +2286,9 @@ class OracleDialect(default.DefaultDialect): """Supported kw arguments are: ``dblink`` to reflect via a db link; ``oracle_resolve_synonyms`` to resolve names to synonyms """ - owner = self.denormalize_name(schema or self.default_schema_name) + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) query = self._column_query(owner) if ( @@ -2505,7 +2523,9 @@ class OracleDialect(default.DefaultDialect): """Supported kw arguments are: ``dblink`` to reflect via a db link; ``oracle_resolve_synonyms`` to resolve names to synonyms """ - owner = self.denormalize_name(schema or self.default_schema_name) + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) has_filter_names, params = self._prepare_filter_names(filter_names) query = self._comment_query(owner, scope, kind, has_filter_names) @@ -2586,7 +2606,9 @@ class OracleDialect(default.DefaultDialect): ("all_objects", InternalTraversal.dp_string_list), ) def _get_indexes_rows(self, connection, schema, dblink, all_objects, **kw): - owner = self.denormalize_name(schema or self.default_schema_name) + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) query = self._index_query(owner) @@ -2755,7 +2777,9 @@ class OracleDialect(default.DefaultDialect): def _get_all_constraint_rows( self, connection, schema, dblink, all_objects, **kw ): - owner = self.denormalize_name(schema or self.default_schema_name) + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) query = self._constraint_query(owner) # since the result is cached a list must be created @@ -2859,7 +2883,9 @@ class OracleDialect(default.DefaultDialect): resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - owner = self.denormalize_name(schema or self.default_schema_name) + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) all_remote_owners = set() fkeys = defaultdict(dict) @@ -3080,7 +3106,9 @@ class OracleDialect(default.DefaultDialect): view_name = row_dict["table_name"] name = self.denormalize_name(view_name) - owner = self.denormalize_name(schema or self.default_schema_name) + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) query = ( select(dictionary.all_views.c.text) .where( diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index ae50264f7b..ac58d36948 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -1597,3 +1597,120 @@ drop table %(schema)sparent; connection.exec_driver_sql("DROP SYNONYM s1") connection.exec_driver_sql("DROP SYNONYM s2") connection.exec_driver_sql("DROP SYNONYM s3") + + @testing.fixture + def public_synonym_fixture(self, connection): + foo_syn = f"foo_syn_{config.ident}" + + connection.exec_driver_sql("CREATE TABLE foobar (id integer)") + + try: + connection.exec_driver_sql( + f"CREATE PUBLIC SYNONYM {foo_syn} for foobar" + ) + except: + # assume the synonym exists is the main problem here. + # since --dropfirst will not get this synonym, drop it directly + # for the next run. + try: + connection.exec_driver_sql(f"DROP PUBLIC SYNONYM {foo_syn}") + except: + pass + + raise + + try: + yield foo_syn + finally: + try: + connection.exec_driver_sql(f"DROP PUBLIC SYNONYM {foo_syn}") + except: + pass + try: + connection.exec_driver_sql("DROP TABLE foobar") + except: + pass + + @testing.variation( + "case_convention", ["uppercase", "lowercase", "mixedcase"] + ) + def test_public_synonym_fetch( + self, + connection, + public_synonym_fixture, + case_convention: testing.Variation, + ): + """test #9459""" + + foo_syn = public_synonym_fixture + + if case_convention.uppercase: + public = "PUBLIC" + elif case_convention.lowercase: + public = "public" + elif case_convention.mixedcase: + public = "Public" + else: + case_convention.fail() + + syns = connection.dialect._get_synonyms(connection, public, None, None) + + if case_convention.mixedcase: + assert not syns + return + + syns_by_name = {syn["synonym_name"]: syn for syn in syns} + eq_( + syns_by_name[foo_syn.upper()], + { + "synonym_name": foo_syn.upper(), + "table_name": "FOOBAR", + "table_owner": connection.dialect.default_schema_name.upper(), + "db_link": None, + }, + ) + + @testing.variation( + "case_convention", ["uppercase", "lowercase", "mixedcase"] + ) + def test_public_synonym_resolve_table( + self, + connection, + public_synonym_fixture, + case_convention: testing.Variation, + ): + """test #9459""" + + foo_syn = public_synonym_fixture + + if case_convention.uppercase: + public = "PUBLIC" + elif case_convention.lowercase: + public = "public" + elif case_convention.mixedcase: + public = "Public" + else: + case_convention.fail() + + if case_convention.mixedcase: + with expect_raises(exc.NoSuchTableError): + cols = inspect(connection).get_columns( + foo_syn, schema=public, oracle_resolve_synonyms=True + ) + else: + cols = inspect(connection).get_columns( + foo_syn, schema=public, oracle_resolve_synonyms=True + ) + + eq_( + cols, + [ + { + "name": "id", + "type": testing.eq_type_affinity(INTEGER), + "nullable": True, + "default": None, + "comment": None, + } + ], + )