From: Mike Bayer Date: Thu, 3 Oct 2019 15:18:06 +0000 (-0400) Subject: Apply quoting to SQL Server _switch_db X-Git-Tag: rel_1_3_9~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a0f47903711afc6ea438d8c6655f3f8011216afd;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Apply quoting to SQL Server _switch_db Added identifier quoting to the schema name applied to the "use" statement which is invoked when a SQL Server multipart schema name is used within a :class:`.Table` that is being reflected, as well as for :class:`.Inspector` methods such as :meth:`.Inspector.get_table_names`; this accommodates for special characters or spaces in the database name. Additionally, the "use" statement is not emitted if the current database matches the target owner database name being passed. Fixes: #4883 Change-Id: I84419730e94aac3a88d331ad8c24d10aabbc34af (cherry picked from commit 66a7befa0c549b92d42afbb5be2b45da13793250) --- diff --git a/doc/build/changelog/unreleased_13/4883.rst b/doc/build/changelog/unreleased_13/4883.rst new file mode 100644 index 0000000000..161dbf1464 --- /dev/null +++ b/doc/build/changelog/unreleased_13/4883.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, mssql + :tickets: 4883 + + Added identifier quoting to the schema name applied to the "use" statement + which is invoked when a SQL Server multipart schema name is used within a + :class:`.Table` that is being reflected, as well as for :class:`.Inspector` + methods such as :meth:`.Inspector.get_table_names`; this accommodates for + special characters or spaces in the database name. Additionally, the "use" + statement is not emitted if the current database matches the target owner + database name being passed. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 37e69d4341..3e770d6e15 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2162,12 +2162,21 @@ def _db_plus_owner(fn): def _switch_db(dbname, connection, fn, *arg, **kw): if dbname: current_db = connection.scalar("select db_name()") - connection.execute("use %s" % dbname) + if current_db != dbname: + connection.execute( + "use %s" + % connection.dialect.identifier_preparer.quote_schema(dbname) + ) try: return fn(*arg, **kw) finally: - if dbname: - connection.execute("use %s" % current_db) + if dbname and current_db != dbname: + connection.execute( + "use %s" + % connection.dialect.identifier_preparer.quote_schema( + current_db + ) + ) def _owner_plus_db(dialect, schema): diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 0815a8f899..4fba61dfe2 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -20,7 +20,6 @@ from sqlalchemy import union from sqlalchemy import UniqueConstraint from sqlalchemy import update from sqlalchemy.dialects import mssql -from sqlalchemy.dialects.mssql import base from sqlalchemy.dialects.mssql import mxodbc from sqlalchemy.dialects.mssql.base import try_cast from sqlalchemy.sql import column @@ -480,21 +479,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): select([tbl]), "SELECT [Foo].dbo.test.id FROM [Foo].dbo.test" ) - def test_owner_database_pairs(self): - dialect = mssql.dialect() - - for identifier, expected_schema, expected_owner in [ - ("foo", None, "foo"), - ("foo.bar", "foo", "bar"), - ("Foo.Bar", "Foo", "Bar"), - ("[Foo.Bar]", None, "Foo.Bar"), - ("[Foo.Bar].[bat]", "Foo.Bar", "bat"), - ]: - schema, owner = base._owner_plus_db(dialect, identifier) - - eq_(owner, expected_owner) - eq_(schema, expected_schema) - def test_delete_schema(self): metadata = MetaData() tbl = Table( diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index 8393a7b482..789fe6526a 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -418,3 +418,86 @@ class ReflectHugeViewTest(fixtures.TestBase): inspector = Inspector.from_engine(testing.db) view_def = inspector.get_view_definition("huge_named_view") eq_(view_def, self.view_str) + + +class OwnerPlusDBTest(fixtures.TestBase): + def test_owner_database_pairs_dont_use_for_same_db(self): + dialect = mssql.dialect() + + identifier = "my_db.some_schema" + schema, owner = base._owner_plus_db(dialect, identifier) + + mock_connection = mock.Mock( + dialect=dialect, scalar=mock.Mock(return_value="my_db") + ) + mock_lambda = mock.Mock() + base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") + eq_(mock_connection.mock_calls, [mock.call.scalar("select db_name()")]) + eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) + + def test_owner_database_pairs_switch_for_different_db(self): + dialect = mssql.dialect() + + identifier = "my_other_db.some_schema" + schema, owner = base._owner_plus_db(dialect, identifier) + + mock_connection = mock.Mock( + dialect=dialect, scalar=mock.Mock(return_value="my_db") + ) + mock_lambda = mock.Mock() + base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") + eq_( + mock_connection.mock_calls, + [ + mock.call.scalar("select db_name()"), + mock.call.execute("use my_other_db"), + mock.call.execute("use my_db"), + ], + ) + eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) + + def test_owner_database_pairs(self): + dialect = mssql.dialect() + + for identifier, expected_schema, expected_owner, use_stmt in [ + ("foo", None, "foo", "use foo"), + ("foo.bar", "foo", "bar", "use foo"), + ("Foo.Bar", "Foo", "Bar", "use [Foo]"), + ("[Foo.Bar]", None, "Foo.Bar", "use [Foo].[Bar]"), + ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo].[Bar]"), + ( + "[foo].]do something; select [foo", + "foo", + "do something; select foo", + "use foo", + ), + ( + "something; select [foo].bar", + "something; select foo", + "bar", + "use [something; select foo]", + ), + ]: + schema, owner = base._owner_plus_db(dialect, identifier) + + eq_(owner, expected_owner) + eq_(schema, expected_schema) + + mock_connection = mock.Mock( + dialect=dialect, + scalar=mock.Mock(return_value="Some ] Database"), + ) + mock_lambda = mock.Mock() + base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar") + if schema is None: + eq_(mock_connection.mock_calls, []) + else: + eq_( + mock_connection.mock_calls, + [ + mock.call.scalar("select db_name()"), + mock.call.execute(use_stmt), + mock.call.execute("use [Some Database]"), + ], + ) + eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])