]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Apply quoting to SQL Server _switch_db
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Oct 2019 15:18:06 +0000 (11:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Oct 2019 15:35:14 +0000 (11:35 -0400)
Added identifier quoting to the schema name applied to the "use" statement
which is invoked when a SQL Server multipart schema name is used within  a
:class:`.Table` that is being reflected, as well as for :class:`.Inspector`
methods such as :meth:`.Inspector.get_table_names`; this accommodates for
special characters or spaces in the database name.  Additionally, the "use"
statement is not emitted if the current database matches the target owner
database name being passed.

Fixes: #4883
Change-Id: I84419730e94aac3a88d331ad8c24d10aabbc34af
(cherry picked from commit 66a7befa0c549b92d42afbb5be2b45da13793250)

doc/build/changelog/unreleased_13/4883.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
test/dialect/mssql/test_compiler.py
test/dialect/mssql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_13/4883.rst b/doc/build/changelog/unreleased_13/4883.rst
new file mode 100644 (file)
index 0000000..161dbf1
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 4883
+
+    Added identifier quoting to the schema name applied to the "use" statement
+    which is invoked when a SQL Server multipart schema name is used within  a
+    :class:`.Table` that is being reflected, as well as for :class:`.Inspector`
+    methods such as :meth:`.Inspector.get_table_names`; this accommodates for
+    special characters or spaces in the database name.  Additionally, the "use"
+    statement is not emitted if the current database matches the target owner
+    database name being passed.
index 37e69d4341ea913340985f4e41680283ed2c719b..3e770d6e15036191d225bc6e9cc6cfdc3e43c5b4 100644 (file)
@@ -2162,12 +2162,21 @@ def _db_plus_owner(fn):
 def _switch_db(dbname, connection, fn, *arg, **kw):
     if dbname:
         current_db = connection.scalar("select db_name()")
-        connection.execute("use %s" % dbname)
+        if current_db != dbname:
+            connection.execute(
+                "use %s"
+                % connection.dialect.identifier_preparer.quote_schema(dbname)
+            )
     try:
         return fn(*arg, **kw)
     finally:
-        if dbname:
-            connection.execute("use %s" % current_db)
+        if dbname and current_db != dbname:
+            connection.execute(
+                "use %s"
+                % connection.dialect.identifier_preparer.quote_schema(
+                    current_db
+                )
+            )
 
 
 def _owner_plus_db(dialect, schema):
index 0815a8f899b5b40df67f56a400251872f0416b75..4fba61dfe26d79a4f55cf712d05cc651d3fbfedc 100644 (file)
@@ -20,7 +20,6 @@ from sqlalchemy import union
 from sqlalchemy import UniqueConstraint
 from sqlalchemy import update
 from sqlalchemy.dialects import mssql
-from sqlalchemy.dialects.mssql import base
 from sqlalchemy.dialects.mssql import mxodbc
 from sqlalchemy.dialects.mssql.base import try_cast
 from sqlalchemy.sql import column
@@ -480,21 +479,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             select([tbl]), "SELECT [Foo].dbo.test.id FROM [Foo].dbo.test"
         )
 
-    def test_owner_database_pairs(self):
-        dialect = mssql.dialect()
-
-        for identifier, expected_schema, expected_owner in [
-            ("foo", None, "foo"),
-            ("foo.bar", "foo", "bar"),
-            ("Foo.Bar", "Foo", "Bar"),
-            ("[Foo.Bar]", None, "Foo.Bar"),
-            ("[Foo.Bar].[bat]", "Foo.Bar", "bat"),
-        ]:
-            schema, owner = base._owner_plus_db(dialect, identifier)
-
-            eq_(owner, expected_owner)
-            eq_(schema, expected_schema)
-
     def test_delete_schema(self):
         metadata = MetaData()
         tbl = Table(
index 8393a7b482f9a5a50fccb9d10a3f0582f5b73594..789fe6526a3c3707a86a037c214337eb7ee64652 100644 (file)
@@ -418,3 +418,86 @@ class ReflectHugeViewTest(fixtures.TestBase):
         inspector = Inspector.from_engine(testing.db)
         view_def = inspector.get_view_definition("huge_named_view")
         eq_(view_def, self.view_str)
+
+
+class OwnerPlusDBTest(fixtures.TestBase):
+    def test_owner_database_pairs_dont_use_for_same_db(self):
+        dialect = mssql.dialect()
+
+        identifier = "my_db.some_schema"
+        schema, owner = base._owner_plus_db(dialect, identifier)
+
+        mock_connection = mock.Mock(
+            dialect=dialect, scalar=mock.Mock(return_value="my_db")
+        )
+        mock_lambda = mock.Mock()
+        base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
+        eq_(mock_connection.mock_calls, [mock.call.scalar("select db_name()")])
+        eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])
+
+    def test_owner_database_pairs_switch_for_different_db(self):
+        dialect = mssql.dialect()
+
+        identifier = "my_other_db.some_schema"
+        schema, owner = base._owner_plus_db(dialect, identifier)
+
+        mock_connection = mock.Mock(
+            dialect=dialect, scalar=mock.Mock(return_value="my_db")
+        )
+        mock_lambda = mock.Mock()
+        base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
+        eq_(
+            mock_connection.mock_calls,
+            [
+                mock.call.scalar("select db_name()"),
+                mock.call.execute("use my_other_db"),
+                mock.call.execute("use my_db"),
+            ],
+        )
+        eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])
+
+    def test_owner_database_pairs(self):
+        dialect = mssql.dialect()
+
+        for identifier, expected_schema, expected_owner, use_stmt in [
+            ("foo", None, "foo", "use foo"),
+            ("foo.bar", "foo", "bar", "use foo"),
+            ("Foo.Bar", "Foo", "Bar", "use [Foo]"),
+            ("[Foo.Bar]", None, "Foo.Bar", "use [Foo].[Bar]"),
+            ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo].[Bar]"),
+            (
+                "[foo].]do something; select [foo",
+                "foo",
+                "do something; select foo",
+                "use foo",
+            ),
+            (
+                "something; select [foo].bar",
+                "something; select foo",
+                "bar",
+                "use [something; select foo]",
+            ),
+        ]:
+            schema, owner = base._owner_plus_db(dialect, identifier)
+
+            eq_(owner, expected_owner)
+            eq_(schema, expected_schema)
+
+            mock_connection = mock.Mock(
+                dialect=dialect,
+                scalar=mock.Mock(return_value="Some ] Database"),
+            )
+            mock_lambda = mock.Mock()
+            base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
+            if schema is None:
+                eq_(mock_connection.mock_calls, [])
+            else:
+                eq_(
+                    mock_connection.mock_calls,
+                    [
+                        mock.call.scalar("select db_name()"),
+                        mock.call.execute(use_stmt),
+                        mock.call.execute("use [Some  Database]"),
+                    ],
+                )
+            eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])