From: Mike Bayer Date: Thu, 3 Oct 2019 15:18:06 +0000 (-0400) Subject: Apply quoting to SQL Server _switch_db X-Git-Tag: rel_1_4_0b1~705^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=66a7befa0c549b92d42afbb5be2b45da13793250;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Apply quoting to SQL Server _switch_db Added identifier quoting to the schema name applied to the "use" statement which is invoked when a SQL Server multipart schema name is used within a :class:`.Table` that is being reflected, as well as for :class:`.Inspector` methods such as :meth:`.Inspector.get_table_names`; this accommodates for special characters or spaces in the database name. Additionally, the "use" statement is not emitted if the current database matches the target owner database name being passed. Fixes: #4883 Change-Id: I84419730e94aac3a88d331ad8c24d10aabbc34af --- diff --git a/doc/build/changelog/unreleased_13/4883.rst b/doc/build/changelog/unreleased_13/4883.rst new file mode 100644 index 0000000000..161dbf1464 --- /dev/null +++ b/doc/build/changelog/unreleased_13/4883.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, mssql + :tickets: 4883 + + Added identifier quoting to the schema name applied to the "use" statement + which is invoked when a SQL Server multipart schema name is used within a + :class:`.Table` that is being reflected, as well as for :class:`.Inspector` + methods such as :meth:`.Inspector.get_table_names`; this accommodates for + special characters or spaces in the database name. Additionally, the "use" + statement is not emitted if the current database matches the target owner + database name being passed. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 54f0043c4b..d4d303d5de 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2174,12 +2174,21 @@ def _db_plus_owner(fn): def _switch_db(dbname, connection, fn, *arg, **kw): if dbname: current_db = connection.scalar("select db_name()") - connection.execute("use %s" % dbname) + if current_db != dbname: + connection.execute( + "use %s" + % connection.dialect.identifier_preparer.quote_schema(dbname) + ) try: return fn(*arg, **kw) finally: - if dbname: - connection.execute("use %s" % current_db) + if dbname and current_db != dbname: + connection.execute( + "use %s" + % connection.dialect.identifier_preparer.quote_schema( + current_db + ) + ) def _owner_plus_db(dialect, schema): diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 4f656a36c9..00a8a08fcc 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -20,7 +20,6 @@ from sqlalchemy import union from sqlalchemy import UniqueConstraint from sqlalchemy import update from sqlalchemy.dialects import mssql -from sqlalchemy.dialects.mssql import base from sqlalchemy.dialects.mssql import mxodbc from sqlalchemy.dialects.mssql.base import try_cast from sqlalchemy.sql import column @@ -480,21 +479,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): select([tbl]), "SELECT [Foo].dbo.test.id FROM [Foo].dbo.test" ) - def test_owner_database_pairs(self): - dialect = mssql.dialect() - - for identifier, expected_schema, expected_owner in [ - ("foo", None, "foo"), - ("foo.bar", "foo", "bar"), - ("Foo.Bar", "Foo", "Bar"), - ("[Foo.Bar]", None, "Foo.Bar"), - ("[Foo.Bar].[bat]", "Foo.Bar", "bat"), - ]: - schema, owner = base._owner_plus_db(dialect, identifier) - - eq_(owner, expected_owner) - eq_(schema, expected_schema) - def test_delete_schema(self): metadata = MetaData() tbl = Table( diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index 937f2bb2aa..24c4a64558 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -418,3 +418,86 @@ class ReflectHugeViewTest(fixtures.TestBase): inspector = Inspector.from_engine(testing.db) view_def = inspector.get_view_definition("huge_named_view") eq_(view_def, self.view_str) + + +class OwnerPlusDBTest(fixtures.TestBase): + def test_owner_database_pairs_dont_use_for_same_db(self): + dialect = mssql.dialect() + + identifier = "my_db.some_schema" + schema, owner = base._owner_plus_db(dialect, identifier) + + mock_connection = mock.Mock( + dialect=dialect, scalar=mock.Mock(return_value="my_db") + ) + mock_lambda = mock.Mock() + base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") + eq_(mock_connection.mock_calls, [mock.call.scalar("select db_name()")]) + eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) + + def test_owner_database_pairs_switch_for_different_db(self): + dialect = mssql.dialect() + + identifier = "my_other_db.some_schema" + schema, owner = base._owner_plus_db(dialect, identifier) + + mock_connection = mock.Mock( + dialect=dialect, scalar=mock.Mock(return_value="my_db") + ) + mock_lambda = mock.Mock() + base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") + eq_( + mock_connection.mock_calls, + [ + mock.call.scalar("select db_name()"), + mock.call.execute("use my_other_db"), + mock.call.execute("use my_db"), + ], + ) + eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) + + def test_owner_database_pairs(self): + dialect = mssql.dialect() + + for identifier, expected_schema, expected_owner, use_stmt in [ + ("foo", None, "foo", "use foo"), + ("foo.bar", "foo", "bar", "use foo"), + ("Foo.Bar", "Foo", "Bar", "use [Foo]"), + ("[Foo.Bar]", None, "Foo.Bar", "use [Foo].[Bar]"), + ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo].[Bar]"), + ( + "[foo].]do something; select [foo", + "foo", + "do something; select foo", + "use foo", + ), + ( + "something; select [foo].bar", + "something; select foo", + "bar", + "use [something; select foo]", + ), + ]: + schema, owner = base._owner_plus_db(dialect, identifier) + + eq_(owner, expected_owner) + eq_(schema, expected_schema) + + mock_connection = mock.Mock( + dialect=dialect, + scalar=mock.Mock(return_value="Some ] Database"), + ) + mock_lambda = mock.Mock() + base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") + if schema is None: + eq_(mock_connection.mock_calls, []) + else: + eq_( + mock_connection.mock_calls, + [ + mock.call.scalar("select db_name()"), + mock.call.execute(use_stmt), + mock.call.execute("use [Some Database]"), + ], + ) + eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])