]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Apply quoting to SQL Server _switch_db
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Oct 2019 15:18:06 +0000 (11:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Oct 2019 15:34:34 +0000 (11:34 -0400)
Added identifier quoting to the schema name applied to the "use" statement
which is invoked when a SQL Server multipart schema name is used within  a
:class:`.Table` that is being reflected, as well as for :class:`.Inspector`
methods such as :meth:`.Inspector.get_table_names`; this accommodates for
special characters or spaces in the database name.  Additionally, the "use"
statement is not emitted if the current database matches the target owner
database name being passed.

Fixes: #4883
Change-Id: I84419730e94aac3a88d331ad8c24d10aabbc34af

doc/build/changelog/unreleased_13/4883.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
test/dialect/mssql/test_compiler.py
test/dialect/mssql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_13/4883.rst b/doc/build/changelog/unreleased_13/4883.rst
new file mode 100644 (file)
index 0000000..161dbf1
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 4883
+
+    Added identifier quoting to the schema name applied to the "use" statement
+    which is invoked when a SQL Server multipart schema name is used within  a
+    :class:`.Table` that is being reflected, as well as for :class:`.Inspector`
+    methods such as :meth:`.Inspector.get_table_names`; this accommodates for
+    special characters or spaces in the database name.  Additionally, the "use"
+    statement is not emitted if the current database matches the target owner
+    database name being passed.
index 54f0043c4bcb9d046c3ea013906ec362576ac915..d4d303d5de84adea550d0ab3f79fd354816907ef 100644 (file)
@@ -2174,12 +2174,21 @@ def _db_plus_owner(fn):
 def _switch_db(dbname, connection, fn, *arg, **kw):
     if dbname:
         current_db = connection.scalar("select db_name()")
-        connection.execute("use %s" % dbname)
+        if current_db != dbname:
+            connection.execute(
+                "use %s"
+                % connection.dialect.identifier_preparer.quote_schema(dbname)
+            )
     try:
         return fn(*arg, **kw)
     finally:
-        if dbname:
-            connection.execute("use %s" % current_db)
+        if dbname and current_db != dbname:
+            connection.execute(
+                "use %s"
+                % connection.dialect.identifier_preparer.quote_schema(
+                    current_db
+                )
+            )
 
 
 def _owner_plus_db(dialect, schema):
index 4f656a36c918a66a2f5be19b571859d6849de37b..00a8a08fcc0583e579411d03ee699fccd8fd4640 100644 (file)
@@ -20,7 +20,6 @@ from sqlalchemy import union
 from sqlalchemy import UniqueConstraint
 from sqlalchemy import update
 from sqlalchemy.dialects import mssql
-from sqlalchemy.dialects.mssql import base
 from sqlalchemy.dialects.mssql import mxodbc
 from sqlalchemy.dialects.mssql.base import try_cast
 from sqlalchemy.sql import column
@@ -480,21 +479,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             select([tbl]), "SELECT [Foo].dbo.test.id FROM [Foo].dbo.test"
         )
 
-    def test_owner_database_pairs(self):
-        dialect = mssql.dialect()
-
-        for identifier, expected_schema, expected_owner in [
-            ("foo", None, "foo"),
-            ("foo.bar", "foo", "bar"),
-            ("Foo.Bar", "Foo", "Bar"),
-            ("[Foo.Bar]", None, "Foo.Bar"),
-            ("[Foo.Bar].[bat]", "Foo.Bar", "bat"),
-        ]:
-            schema, owner = base._owner_plus_db(dialect, identifier)
-
-            eq_(owner, expected_owner)
-            eq_(schema, expected_schema)
-
     def test_delete_schema(self):
         metadata = MetaData()
         tbl = Table(
index 937f2bb2aa13e3384b20f1fafbdebd35958740ab..24c4a645584df3bb3a7cb1b02dad75ed7391a39e 100644 (file)
@@ -418,3 +418,86 @@ class ReflectHugeViewTest(fixtures.TestBase):
         inspector = Inspector.from_engine(testing.db)
         view_def = inspector.get_view_definition("huge_named_view")
         eq_(view_def, self.view_str)
+
+
+class OwnerPlusDBTest(fixtures.TestBase):
+    def test_owner_database_pairs_dont_use_for_same_db(self):
+        dialect = mssql.dialect()
+
+        identifier = "my_db.some_schema"
+        schema, owner = base._owner_plus_db(dialect, identifier)
+
+        mock_connection = mock.Mock(
+            dialect=dialect, scalar=mock.Mock(return_value="my_db")
+        )
+        mock_lambda = mock.Mock()
+        base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
+        eq_(mock_connection.mock_calls, [mock.call.scalar("select db_name()")])
+        eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])
+
+    def test_owner_database_pairs_switch_for_different_db(self):
+        dialect = mssql.dialect()
+
+        identifier = "my_other_db.some_schema"
+        schema, owner = base._owner_plus_db(dialect, identifier)
+
+        mock_connection = mock.Mock(
+            dialect=dialect, scalar=mock.Mock(return_value="my_db")
+        )
+        mock_lambda = mock.Mock()
+        base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
+        eq_(
+            mock_connection.mock_calls,
+            [
+                mock.call.scalar("select db_name()"),
+                mock.call.execute("use my_other_db"),
+                mock.call.execute("use my_db"),
+            ],
+        )
+        eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])
+
+    def test_owner_database_pairs(self):
+        dialect = mssql.dialect()
+
+        for identifier, expected_schema, expected_owner, use_stmt in [
+            ("foo", None, "foo", "use foo"),
+            ("foo.bar", "foo", "bar", "use foo"),
+            ("Foo.Bar", "Foo", "Bar", "use [Foo]"),
+            ("[Foo.Bar]", None, "Foo.Bar", "use [Foo].[Bar]"),
+            ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo].[Bar]"),
+            (
+                "[foo].]do something; select [foo",
+                "foo",
+                "do something; select foo",
+                "use foo",
+            ),
+            (
+                "something; select [foo].bar",
+                "something; select foo",
+                "bar",
+                "use [something; select foo]",
+            ),
+        ]:
+            schema, owner = base._owner_plus_db(dialect, identifier)
+
+            eq_(owner, expected_owner)
+            eq_(schema, expected_schema)
+
+            mock_connection = mock.Mock(
+                dialect=dialect,
+                scalar=mock.Mock(return_value="Some ] Database"),
+            )
+            mock_lambda = mock.Mock()
+            base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
+            if schema is None:
+                eq_(mock_connection.mock_calls, [])
+            else:
+                eq_(
+                    mock_connection.mock_calls,
+                    [
+                        mock.call.scalar("select db_name()"),
+                        mock.call.execute(use_stmt),
+                        mock.call.execute("use [Some  Database]"),
+                    ],
+                )
+            eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])