From: Mike Bayer Date: Mon, 1 Jun 2020 00:34:03 +0000 (-0400) Subject: Support multiple dotted sections in mssql schema names X-Git-Tag: rel_1_4_0b1~288^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a7a19f292451e10aef489d87df27be7f58f831a8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support multiple dotted sections in mssql schema names Refined the logic used by the SQL Server dialect to interpret multi-part schema names that contain many dots, to not actually lose any dots if the name does not have bracking or quoting used, and additionally to support a "dbname" token that has many parts including that it may have multiple, independently-bracketed sections. This fix addresses #5364 to some degree but probably does not resolve it fully. References: #5364 Fixes: #5366 Change-Id: I460cd74ce443efb35fb63b6864f00c6d81422688 --- diff --git a/doc/build/changelog/unreleased_13/5366.rst b/doc/build/changelog/unreleased_13/5366.rst new file mode 100644 index 0000000000..ff694397f1 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5366.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, mssql + :tickets: 5366, 5364 + + Refined the logic used by the SQL Server dialect to interpret multi-part + schema names that contain many dots, to not actually lose any dots if the + name does not have bracking or quoting used, and additionally to support a + "dbname" token that has many parts including that it may have multiple, + independently-bracketed sections. + + diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 5e07045978..bbf44906c9 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2287,8 +2287,7 @@ def _switch_db(dbname, connection, fn, *arg, **kw): current_db = connection.exec_driver_sql("select db_name()").scalar() if current_db != dbname: connection.exec_driver_sql( - "use %s" - % connection.dialect.identifier_preparer.quote_schema(dbname) + "use %s" % connection.dialect.identifier_preparer.quote(dbname) ) try: return fn(*arg, **kw) @@ -2296,9 +2295,7 @@ def _switch_db(dbname, connection, fn, *arg, **kw): if dbname and current_db != dbname: connection.exec_driver_sql( "use %s" - % connection.dialect.identifier_preparer.quote_schema( - current_db - ) + % connection.dialect.identifier_preparer.quote(current_db) ) @@ -2311,33 +2308,62 @@ def _owner_plus_db(dialect, schema): return None, schema +_memoized_schema = util.LRUCache() + + def _schema_elements(schema): if isinstance(schema, quoted_name) and schema.quote: return None, schema + if schema in _memoized_schema: + return _memoized_schema[schema] + + # tests for this function are in: + # test/dialect/mssql/test_reflection.py -> + # OwnerPlusDBTest.test_owner_database_pairs + # test/dialect/mssql/test_compiler.py -> test_force_schema_* + # test/dialect/mssql/test_compiler.py -> test_schema_many_tokens_* + # + push = [] symbol = "" bracket = False + has_brackets = False for token in re.split(r"(\[|\]|\.)", schema): if not token: continue if token == "[": bracket = True + has_brackets = True elif token == "]": bracket = False elif not bracket and token == ".": - push.append(symbol) + if has_brackets: + push.append("[%s]" % symbol) + else: + push.append(symbol) symbol = "" + has_brackets = False else: symbol += token if symbol: push.append(symbol) if len(push) > 1: - return push[0], "".join(push[1:]) + dbname, owner = ".".join(push[0:-1]), push[-1] + + # test for internal brackets + if re.match(r".*\].*\[.*", dbname[1:-1]): + dbname = quoted_name(dbname, quote=False) + else: + dbname = dbname.lstrip("[").rstrip("]") + elif len(push): - return None, push[0] + dbname, owner = None, push[0] else: - return None, None + dbname, owner = None, None + + _memoized_schema[schema] = dbname, owner + return dbname, owner class MSDialect(default.DefaultDialect): diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index b7a06c8e3e..25af3240ec 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -22,6 +22,7 @@ from sqlalchemy import union from sqlalchemy import UniqueConstraint from sqlalchemy import update from sqlalchemy.dialects import mssql +from sqlalchemy.dialects.mssql import base as mssql_base from sqlalchemy.dialects.mssql import mxodbc from sqlalchemy.dialects.mssql.base import try_cast from sqlalchemy.sql import column @@ -525,6 +526,42 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): checkpositional=("bar",), ) + def test_schema_many_tokens_one(self): + metadata = MetaData() + tbl = Table( + "test", + metadata, + Column("id", Integer, primary_key=True), + schema="abc.def.efg.hij", + ) + + # for now, we don't really know what the above means, at least + # don't lose the dot + self.assert_compile( + select([tbl]), + "SELECT [abc.def.efg].hij.test.id FROM [abc.def.efg].hij.test", + ) + + dbname, owner = mssql_base._schema_elements("abc.def.efg.hij") + eq_(dbname, "abc.def.efg") + assert not isinstance(dbname, quoted_name) + eq_(owner, "hij") + + def test_schema_many_tokens_two(self): + metadata = MetaData() + tbl = Table( + "test", + metadata, + Column("id", Integer, primary_key=True), + schema="[abc].[def].[efg].[hij]", + ) + + self.assert_compile( + select([tbl]), + "SELECT [abc].[def].[efg].hij.test.id " + "FROM [abc].[def].[efg].hij.test", + ) + def test_force_schema_quoted_name_w_dot_case_insensitive(self): metadata = MetaData() tbl = Table( diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index 5328513e4c..176d3d2ecb 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -484,56 +484,65 @@ class OwnerPlusDBTest(fixtures.TestBase): ) eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) - def test_owner_database_pairs(self): + @testing.combinations( + ("foo", None, "foo", "use foo"), + ("foo.bar", "foo", "bar", "use foo"), + ("Foo.Bar", "Foo", "Bar", "use [Foo]"), + ("[Foo.Bar]", None, "Foo.Bar", "use [Foo.Bar]"), + ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo.Bar]"), + ( + "[foo].]do something; select [foo", + "foo", + "do something; select foo", + "use foo", + ), + ( + "something; select [foo].bar", + "something; select foo", + "bar", + "use [something; select foo]", + ), + ( + "[abc].[def].[efg].[hij]", + "[abc].[def].[efg]", + "hij", + "use [abc].[def].[efg]", + ), + ("abc.def.efg.hij", "abc.def.efg", "hij", "use [abc.def.efg]"), + ) + def test_owner_database_pairs( + self, identifier, expected_schema, expected_owner, use_stmt + ): dialect = mssql.dialect() - for identifier, expected_schema, expected_owner, use_stmt in [ - ("foo", None, "foo", "use foo"), - ("foo.bar", "foo", "bar", "use foo"), - ("Foo.Bar", "Foo", "Bar", "use [Foo]"), - ("[Foo.Bar]", None, "Foo.Bar", "use [Foo].[Bar]"), - ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo].[Bar]"), - ( - "[foo].]do something; select [foo", - "foo", - "do something; select foo", - "use foo", - ), - ( - "something; select [foo].bar", - "something; select foo", - "bar", - "use [something; select foo]", + schema, owner = base._owner_plus_db(dialect, identifier) + + eq_(owner, expected_owner) + eq_(schema, expected_schema) + + mock_connection = mock.Mock( + dialect=dialect, + exec_driver_sql=mock.Mock( + return_value=mock.Mock( + scalar=mock.Mock(return_value="Some Database") + ) ), - ]: - schema, owner = base._owner_plus_db(dialect, identifier) - - eq_(owner, expected_owner) - eq_(schema, expected_schema) - - mock_connection = mock.Mock( - dialect=dialect, - exec_driver_sql=mock.Mock( - return_value=mock.Mock( - scalar=mock.Mock(return_value="Some ] Database") - ) - ), + ) + mock_lambda = mock.Mock() + base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") + if schema is None: + eq_(mock_connection.mock_calls, []) + else: + eq_( + mock_connection.mock_calls, + [ + mock.call.exec_driver_sql("select db_name()"), + mock.call.exec_driver_sql(use_stmt), + mock.call.exec_driver_sql("use [Some Database]"), + ], ) - mock_lambda = mock.Mock() - base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") - if schema is None: - eq_(mock_connection.mock_calls, []) - else: - eq_( - mock_connection.mock_calls, - [ - mock.call.exec_driver_sql("select db_name()"), - mock.call.exec_driver_sql(use_stmt), - mock.call.exec_driver_sql("use [Some Database]"), - ], - ) - eq_( - mock_connection.exec_driver_sql.return_value.mock_calls, - [mock.call.scalar()], - ) - eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) + eq_( + mock_connection.exec_driver_sql.return_value.mock_calls, + [mock.call.scalar()], + ) + eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])