From: Mike Bayer Date: Mon, 1 Jun 2020 00:34:03 +0000 (-0400) Subject: Support multiple dotted sections in mssql schema names X-Git-Tag: rel_1_3_18~15^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=dd0c170151583ad38183a8f833d8adc9815d2902;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support multiple dotted sections in mssql schema names Refined the logic used by the SQL Server dialect to interpret multi-part schema names that contain many dots, to not actually lose any dots if the name does not have bracking or quoting used, and additionally to support a "dbname" token that has many parts including that it may have multiple, independently-bracketed sections. This fix addresses #5364 to some degree but probably does not resolve it fully. References: #5364 Fixes: #5366 Change-Id: I460cd74ce443efb35fb63b6864f00c6d81422688 (cherry picked from commit 9aff0102813900fd7bc2120df5e1cfa169edb44f) --- diff --git a/doc/build/changelog/unreleased_13/5366.rst b/doc/build/changelog/unreleased_13/5366.rst new file mode 100644 index 0000000000..ff694397f1 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5366.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, mssql + :tickets: 5366, 5364 + + Refined the logic used by the SQL Server dialect to interpret multi-part + schema names that contain many dots, to not actually lose any dots if the + name does not have bracking or quoting used, and additionally to support a + "dbname" token that has many parts including that it may have multiple, + independently-bracketed sections. + + diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 496df4dc27..4f673fba14 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2211,8 +2211,7 @@ def _switch_db(dbname, connection, fn, *arg, **kw): current_db = connection.scalar("select db_name()") if current_db != dbname: connection.execute( - "use %s" - % connection.dialect.identifier_preparer.quote_schema(dbname) + "use %s" % connection.dialect.identifier_preparer.quote(dbname) ) try: return fn(*arg, **kw) @@ -2220,9 +2219,7 @@ def _switch_db(dbname, connection, fn, *arg, **kw): if dbname and current_db != dbname: connection.execute( "use %s" - % connection.dialect.identifier_preparer.quote_schema( - current_db - ) + % connection.dialect.identifier_preparer.quote(current_db) ) @@ -2235,33 +2232,62 @@ def _owner_plus_db(dialect, schema): return None, schema +_memoized_schema = util.LRUCache() + + def _schema_elements(schema): if isinstance(schema, quoted_name) and schema.quote: return None, schema + if schema in _memoized_schema: + return _memoized_schema[schema] + + # tests for this function are in: + # test/dialect/mssql/test_reflection.py -> + # OwnerPlusDBTest.test_owner_database_pairs + # test/dialect/mssql/test_compiler.py -> test_force_schema_* + # test/dialect/mssql/test_compiler.py -> test_schema_many_tokens_* + # + push = [] symbol = "" bracket = False + has_brackets = False for token in re.split(r"(\[|\]|\.)", schema): if not token: continue if token == "[": bracket = True + has_brackets = True elif token == "]": bracket = False elif not bracket and token == ".": - push.append(symbol) + if has_brackets: + push.append("[%s]" % symbol) + else: + push.append(symbol) symbol = "" + has_brackets = False else: symbol += token if symbol: push.append(symbol) if len(push) > 1: - return push[0], "".join(push[1:]) + dbname, owner = ".".join(push[0:-1]), push[-1] + + # test for internal brackets + if re.match(r".*\].*\[.*", dbname[1:-1]): + dbname = quoted_name(dbname, quote=False) + else: + dbname = dbname.lstrip("[").rstrip("]") + elif len(push): - return None, push[0] + dbname, owner = None, push[0] else: - return None, None + dbname, owner = None, None + + _memoized_schema[schema] = dbname, owner + return dbname, owner class MSDialect(default.DefaultDialect): diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 1865fd589b..3656ed9b30 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -22,6 +22,7 @@ from sqlalchemy import union from sqlalchemy import UniqueConstraint from sqlalchemy import update from sqlalchemy.dialects import mssql +from sqlalchemy.dialects.mssql import base as mssql_base from sqlalchemy.dialects.mssql import mxodbc from sqlalchemy.dialects.mssql.base import try_cast from sqlalchemy.sql import column @@ -409,6 +410,42 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): checkpositional=("bar",), ) + def test_schema_many_tokens_one(self): + metadata = MetaData() + tbl = Table( + "test", + metadata, + Column("id", Integer, primary_key=True), + schema="abc.def.efg.hij", + ) + + # for now, we don't really know what the above means, at least + # don't lose the dot + self.assert_compile( + select([tbl]), + "SELECT [abc.def.efg].hij.test.id FROM [abc.def.efg].hij.test", + ) + + dbname, owner = mssql_base._schema_elements("abc.def.efg.hij") + eq_(dbname, "abc.def.efg") + assert not isinstance(dbname, quoted_name) + eq_(owner, "hij") + + def test_schema_many_tokens_two(self): + metadata = MetaData() + tbl = Table( + "test", + metadata, + Column("id", Integer, primary_key=True), + schema="[abc].[def].[efg].[hij]", + ) + + self.assert_compile( + select([tbl]), + "SELECT [abc].[def].[efg].hij.test.id " + "FROM [abc].[def].[efg].hij.test", + ) + def test_force_schema_quoted_name_w_dot_case_insensitive(self): metadata = MetaData() tbl = Table( diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index 347adbbd6f..6d74d19f70 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -468,48 +468,56 @@ class OwnerPlusDBTest(fixtures.TestBase): ) eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) - def test_owner_database_pairs(self): + @testing.combinations( + ("foo", None, "foo", "use foo"), + ("foo.bar", "foo", "bar", "use foo"), + ("Foo.Bar", "Foo", "Bar", "use [Foo]"), + ("[Foo.Bar]", None, "Foo.Bar", "use [Foo.Bar]"), + ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo.Bar]"), + ( + "[foo].]do something; select [foo", + "foo", + "do something; select foo", + "use foo", + ), + ( + "something; select [foo].bar", + "something; select foo", + "bar", + "use [something; select foo]", + ), + ( + "[abc].[def].[efg].[hij]", + "[abc].[def].[efg]", + "hij", + "use [abc].[def].[efg]", + ), + ("abc.def.efg.hij", "abc.def.efg", "hij", "use [abc.def.efg]"), + ) + def test_owner_database_pairs( + self, identifier, expected_schema, expected_owner, use_stmt + ): dialect = mssql.dialect() - for identifier, expected_schema, expected_owner, use_stmt in [ - ("foo", None, "foo", "use foo"), - ("foo.bar", "foo", "bar", "use foo"), - ("Foo.Bar", "Foo", "Bar", "use [Foo]"), - ("[Foo.Bar]", None, "Foo.Bar", "use [Foo].[Bar]"), - ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo].[Bar]"), - ( - "[foo].]do something; select [foo", - "foo", - "do something; select foo", - "use foo", - ), - ( - "something; select [foo].bar", - "something; select foo", - "bar", - "use [something; select foo]", - ), - ]: - schema, owner = base._owner_plus_db(dialect, identifier) + schema, owner = base._owner_plus_db(dialect, identifier) - eq_(owner, expected_owner) - eq_(schema, expected_schema) + eq_(owner, expected_owner) + eq_(schema, expected_schema) - mock_connection = mock.Mock( - dialect=dialect, - scalar=mock.Mock(return_value="Some ] Database"), + mock_connection = mock.Mock( + dialect=dialect, scalar=mock.Mock(return_value="Some Database"), + ) + mock_lambda = mock.Mock() + base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") + if schema is None: + eq_(mock_connection.mock_calls, []) + else: + eq_( + mock_connection.mock_calls, + [ + mock.call.scalar("select db_name()"), + mock.call.execute(use_stmt), + mock.call.execute("use [Some Database]"), + ], ) - mock_lambda = mock.Mock() - base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") - if schema is None: - eq_(mock_connection.mock_calls, []) - else: - eq_( - mock_connection.mock_calls, - [ - mock.call.scalar("select db_name()"), - mock.call.execute(use_stmt), - mock.call.execute("use [Some Database]"), - ], - ) - eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) + eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])