]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support multiple dotted sections in mssql schema names
author Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Jun 2020 00:34:03 +0000 (20:34 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Jun 2020 19:38:53 +0000 (15:38 -0400)
Refined the logic used by the SQL Server dialect to interpret multi-part
schema names that contain many dots, to not actually lose any dots if the
name does not have bracketing or quoting used, and additionally to support a
"dbname" token that has many parts including that it may have multiple,
independently-bracketed sections.

This fix addresses #5364 to some degree but probably does not
resolve it fully.

References: #5364
Fixes: #5366
Change-Id: I460cd74ce443efb35fb63b6864f00c6d81422688
(cherry picked from commit 9aff0102813900fd7bc2120df5e1cfa169edb44f)

doc/build/changelog/unreleased_13/5366.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
test/dialect/mssql/test_compiler.py
test/dialect/mssql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_13/5366.rst b/doc/build/changelog/unreleased_13/5366.rst
new file mode 100644 (file)
index 0000000..ff69439
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 5366, 5364
+
+    Refined the logic used by the SQL Server dialect to interpret multi-part
+    schema names that contain many dots, to not actually lose any dots if the
+    name does not have bracketing or quoting used, and additionally to support a
+    "dbname" token that has many parts including that it may have multiple,
+    independently-bracketed sections.
+
+
index 496df4dc273af27b13ddc648d62f10ccab3a9c63..4f673fba14d9885add4ae2da95c4380e9a04921a 100644 (file)
@@ -2211,8 +2211,7 @@ def _switch_db(dbname, connection, fn, *arg, **kw):
         current_db = connection.scalar("select db_name()")
         if current_db != dbname:
             connection.execute(
-                "use %s"
-                % connection.dialect.identifier_preparer.quote_schema(dbname)
+                "use %s" % connection.dialect.identifier_preparer.quote(dbname)
             )
     try:
         return fn(*arg, **kw)
@@ -2220,9 +2219,7 @@ def _switch_db(dbname, connection, fn, *arg, **kw):
         if dbname and current_db != dbname:
             connection.execute(
                 "use %s"
-                % connection.dialect.identifier_preparer.quote_schema(
-                    current_db
-                )
+                % connection.dialect.identifier_preparer.quote(current_db)
             )
 
 
@@ -2235,33 +2232,62 @@ def _owner_plus_db(dialect, schema):
         return None, schema
 
 
+_memoized_schema = util.LRUCache()
+
+
 def _schema_elements(schema):
     if isinstance(schema, quoted_name) and schema.quote:
         return None, schema
 
+    if schema in _memoized_schema:
+        return _memoized_schema[schema]
+
+    # tests for this function are in:
+    # test/dialect/mssql/test_reflection.py ->
+    #           OwnerPlusDBTest.test_owner_database_pairs
+    # test/dialect/mssql/test_compiler.py -> test_force_schema_*
+    # test/dialect/mssql/test_compiler.py -> test_schema_many_tokens_*
+    #
+
     push = []
     symbol = ""
     bracket = False
+    has_brackets = False
     for token in re.split(r"(\[|\]|\.)", schema):
         if not token:
             continue
         if token == "[":
             bracket = True
+            has_brackets = True
         elif token == "]":
             bracket = False
         elif not bracket and token == ".":
-            push.append(symbol)
+            if has_brackets:
+                push.append("[%s]" % symbol)
+            else:
+                push.append(symbol)
             symbol = ""
+            has_brackets = False
         else:
             symbol += token
     if symbol:
         push.append(symbol)
     if len(push) > 1:
-        return push[0], "".join(push[1:])
+        dbname, owner = ".".join(push[0:-1]), push[-1]
+
+        # test for internal brackets
+        if re.match(r".*\].*\[.*", dbname[1:-1]):
+            dbname = quoted_name(dbname, quote=False)
+        else:
+            dbname = dbname.lstrip("[").rstrip("]")
+
     elif len(push):
-        return None, push[0]
+        dbname, owner = None, push[0]
     else:
-        return None, None
+        dbname, owner = None, None
+
+    _memoized_schema[schema] = dbname, owner
+    return dbname, owner
 
 
 class MSDialect(default.DefaultDialect):
index 1865fd589b1ef29e0408658d3a1fa3d872ab8cc5..3656ed9b302adb3deeb4e3fbc64e995f9d267f7a 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy import union
 from sqlalchemy import UniqueConstraint
 from sqlalchemy import update
 from sqlalchemy.dialects import mssql
+from sqlalchemy.dialects.mssql import base as mssql_base
 from sqlalchemy.dialects.mssql import mxodbc
 from sqlalchemy.dialects.mssql.base import try_cast
 from sqlalchemy.sql import column
@@ -409,6 +410,42 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             checkpositional=("bar",),
         )
 
+    def test_schema_many_tokens_one(self):
+        metadata = MetaData()
+        tbl = Table(
+            "test",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            schema="abc.def.efg.hij",
+        )
+
+        # for now, we don't really know what the above means, at least
+        # don't lose the dot
+        self.assert_compile(
+            select([tbl]),
+            "SELECT [abc.def.efg].hij.test.id FROM [abc.def.efg].hij.test",
+        )
+
+        dbname, owner = mssql_base._schema_elements("abc.def.efg.hij")
+        eq_(dbname, "abc.def.efg")
+        assert not isinstance(dbname, quoted_name)
+        eq_(owner, "hij")
+
+    def test_schema_many_tokens_two(self):
+        metadata = MetaData()
+        tbl = Table(
+            "test",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            schema="[abc].[def].[efg].[hij]",
+        )
+
+        self.assert_compile(
+            select([tbl]),
+            "SELECT [abc].[def].[efg].hij.test.id "
+            "FROM [abc].[def].[efg].hij.test",
+        )
+
     def test_force_schema_quoted_name_w_dot_case_insensitive(self):
         metadata = MetaData()
         tbl = Table(
index 347adbbd6f3f105d9a11cdb638b401c00c91c6db..6d74d19f70eaeaf81b4138e0d1ec452d45c0c44f 100644 (file)
@@ -468,48 +468,56 @@ class OwnerPlusDBTest(fixtures.TestBase):
         )
         eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])
 
-    def test_owner_database_pairs(self):
+    @testing.combinations(
+        ("foo", None, "foo", "use foo"),
+        ("foo.bar", "foo", "bar", "use foo"),
+        ("Foo.Bar", "Foo", "Bar", "use [Foo]"),
+        ("[Foo.Bar]", None, "Foo.Bar", "use [Foo.Bar]"),
+        ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo.Bar]"),
+        (
+            "[foo].]do something; select [foo",
+            "foo",
+            "do something; select foo",
+            "use foo",
+        ),
+        (
+            "something; select [foo].bar",
+            "something; select foo",
+            "bar",
+            "use [something; select foo]",
+        ),
+        (
+            "[abc].[def].[efg].[hij]",
+            "[abc].[def].[efg]",
+            "hij",
+            "use [abc].[def].[efg]",
+        ),
+        ("abc.def.efg.hij", "abc.def.efg", "hij", "use [abc.def.efg]"),
+    )
+    def test_owner_database_pairs(
+        self, identifier, expected_schema, expected_owner, use_stmt
+    ):
         dialect = mssql.dialect()
 
-        for identifier, expected_schema, expected_owner, use_stmt in [
-            ("foo", None, "foo", "use foo"),
-            ("foo.bar", "foo", "bar", "use foo"),
-            ("Foo.Bar", "Foo", "Bar", "use [Foo]"),
-            ("[Foo.Bar]", None, "Foo.Bar", "use [Foo].[Bar]"),
-            ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo].[Bar]"),
-            (
-                "[foo].]do something; select [foo",
-                "foo",
-                "do something; select foo",
-                "use foo",
-            ),
-            (
-                "something; select [foo].bar",
-                "something; select foo",
-                "bar",
-                "use [something; select foo]",
-            ),
-        ]:
-            schema, owner = base._owner_plus_db(dialect, identifier)
+        schema, owner = base._owner_plus_db(dialect, identifier)
 
-            eq_(owner, expected_owner)
-            eq_(schema, expected_schema)
+        eq_(owner, expected_owner)
+        eq_(schema, expected_schema)
 
-            mock_connection = mock.Mock(
-                dialect=dialect,
-                scalar=mock.Mock(return_value="Some ] Database"),
+        mock_connection = mock.Mock(
+            dialect=dialect, scalar=mock.Mock(return_value="Some  Database"),
+        )
+        mock_lambda = mock.Mock()
+        base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
+        if schema is None:
+            eq_(mock_connection.mock_calls, [])
+        else:
+            eq_(
+                mock_connection.mock_calls,
+                [
+                    mock.call.scalar("select db_name()"),
+                    mock.call.execute(use_stmt),
+                    mock.call.execute("use [Some  Database]"),
+                ],
             )
-            mock_lambda = mock.Mock()
-            base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
-            if schema is None:
-                eq_(mock_connection.mock_calls, [])
-            else:
-                eq_(
-                    mock_connection.mock_calls,
-                    [
-                        mock.call.scalar("select db_name()"),
-                        mock.call.execute(use_stmt),
-                        mock.call.execute("use [Some  Database]"),
-                    ],
-                )
-            eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])
+        eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])