]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support multiple dotted sections in mssql schema names
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Jun 2020 00:34:03 +0000 (20:34 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Jun 2020 19:37:48 +0000 (15:37 -0400)
Refined the logic used by the SQL Server dialect to interpret multi-part
schema names that contain many dots, to not actually lose any dots if the
name does not have bracking or quoting used, and additionally to support a
"dbname" token that has many parts including that it may have multiple,
independently-bracketed sections.

This fix addresses #5364 to some degree but probably does not
resolve it fully.

References: #5364
Fixes: #5366
Change-Id: I460cd74ce443efb35fb63b6864f00c6d81422688

doc/build/changelog/unreleased_13/5366.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
test/dialect/mssql/test_compiler.py
test/dialect/mssql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_13/5366.rst b/doc/build/changelog/unreleased_13/5366.rst
new file mode 100644 (file)
index 0000000..ff69439
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 5366, 5364
+
+    Refined the logic used by the SQL Server dialect to interpret multi-part
+    schema names that contain many dots, to not actually lose any dots if the
+    name does not have bracking or quoting used, and additionally to support a
+    "dbname" token that has many parts including that it may have multiple,
+    independently-bracketed sections.
+
+
index 5e07045978435e612d1e866b8a5ffb068d323c62..bbf44906c9f52f40d4015147e2088261341e0547 100644 (file)
@@ -2287,8 +2287,7 @@ def _switch_db(dbname, connection, fn, *arg, **kw):
         current_db = connection.exec_driver_sql("select db_name()").scalar()
         if current_db != dbname:
             connection.exec_driver_sql(
-                "use %s"
-                % connection.dialect.identifier_preparer.quote_schema(dbname)
+                "use %s" % connection.dialect.identifier_preparer.quote(dbname)
             )
     try:
         return fn(*arg, **kw)
@@ -2296,9 +2295,7 @@ def _switch_db(dbname, connection, fn, *arg, **kw):
         if dbname and current_db != dbname:
             connection.exec_driver_sql(
                 "use %s"
-                % connection.dialect.identifier_preparer.quote_schema(
-                    current_db
-                )
+                % connection.dialect.identifier_preparer.quote(current_db)
             )
 
 
@@ -2311,33 +2308,62 @@ def _owner_plus_db(dialect, schema):
         return None, schema
 
 
+_memoized_schema = util.LRUCache()
+
+
 def _schema_elements(schema):
     if isinstance(schema, quoted_name) and schema.quote:
         return None, schema
 
+    if schema in _memoized_schema:
+        return _memoized_schema[schema]
+
+    # tests for this function are in:
+    # test/dialect/mssql/test_reflection.py ->
+    #           OwnerPlusDBTest.test_owner_database_pairs
+    # test/dialect/mssql/test_compiler.py -> test_force_schema_*
+    # test/dialect/mssql/test_compiler.py -> test_schema_many_tokens_*
+    #
+
     push = []
     symbol = ""
     bracket = False
+    has_brackets = False
     for token in re.split(r"(\[|\]|\.)", schema):
         if not token:
             continue
         if token == "[":
             bracket = True
+            has_brackets = True
         elif token == "]":
             bracket = False
         elif not bracket and token == ".":
-            push.append(symbol)
+            if has_brackets:
+                push.append("[%s]" % symbol)
+            else:
+                push.append(symbol)
             symbol = ""
+            has_brackets = False
         else:
             symbol += token
     if symbol:
         push.append(symbol)
     if len(push) > 1:
-        return push[0], "".join(push[1:])
+        dbname, owner = ".".join(push[0:-1]), push[-1]
+
+        # test for internal brackets
+        if re.match(r".*\].*\[.*", dbname[1:-1]):
+            dbname = quoted_name(dbname, quote=False)
+        else:
+            dbname = dbname.lstrip("[").rstrip("]")
+
     elif len(push):
-        return None, push[0]
+        dbname, owner = None, push[0]
     else:
-        return None, None
+        dbname, owner = None, None
+
+    _memoized_schema[schema] = dbname, owner
+    return dbname, owner
 
 
 class MSDialect(default.DefaultDialect):
index b7a06c8e3eb4666452c7626c923f7a62da941343..25af3240ecffd167d7c3e8df7cecf1460a15c955 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy import union
 from sqlalchemy import UniqueConstraint
 from sqlalchemy import update
 from sqlalchemy.dialects import mssql
+from sqlalchemy.dialects.mssql import base as mssql_base
 from sqlalchemy.dialects.mssql import mxodbc
 from sqlalchemy.dialects.mssql.base import try_cast
 from sqlalchemy.sql import column
@@ -525,6 +526,42 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             checkpositional=("bar",),
         )
 
+    def test_schema_many_tokens_one(self):
+        metadata = MetaData()
+        tbl = Table(
+            "test",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            schema="abc.def.efg.hij",
+        )
+
+        # for now, we don't really know what the above means, at least
+        # don't lose the dot
+        self.assert_compile(
+            select([tbl]),
+            "SELECT [abc.def.efg].hij.test.id FROM [abc.def.efg].hij.test",
+        )
+
+        dbname, owner = mssql_base._schema_elements("abc.def.efg.hij")
+        eq_(dbname, "abc.def.efg")
+        assert not isinstance(dbname, quoted_name)
+        eq_(owner, "hij")
+
+    def test_schema_many_tokens_two(self):
+        metadata = MetaData()
+        tbl = Table(
+            "test",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            schema="[abc].[def].[efg].[hij]",
+        )
+
+        self.assert_compile(
+            select([tbl]),
+            "SELECT [abc].[def].[efg].hij.test.id "
+            "FROM [abc].[def].[efg].hij.test",
+        )
+
     def test_force_schema_quoted_name_w_dot_case_insensitive(self):
         metadata = MetaData()
         tbl = Table(
index 5328513e4cbf661b04a1afc7e5b7b95b8c884a9f..176d3d2ecbaa79daac0189c2988889258898869d 100644 (file)
@@ -484,56 +484,65 @@ class OwnerPlusDBTest(fixtures.TestBase):
         )
         eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])
 
-    def test_owner_database_pairs(self):
+    @testing.combinations(
+        ("foo", None, "foo", "use foo"),
+        ("foo.bar", "foo", "bar", "use foo"),
+        ("Foo.Bar", "Foo", "Bar", "use [Foo]"),
+        ("[Foo.Bar]", None, "Foo.Bar", "use [Foo.Bar]"),
+        ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo.Bar]"),
+        (
+            "[foo].]do something; select [foo",
+            "foo",
+            "do something; select foo",
+            "use foo",
+        ),
+        (
+            "something; select [foo].bar",
+            "something; select foo",
+            "bar",
+            "use [something; select foo]",
+        ),
+        (
+            "[abc].[def].[efg].[hij]",
+            "[abc].[def].[efg]",
+            "hij",
+            "use [abc].[def].[efg]",
+        ),
+        ("abc.def.efg.hij", "abc.def.efg", "hij", "use [abc.def.efg]"),
+    )
+    def test_owner_database_pairs(
+        self, identifier, expected_schema, expected_owner, use_stmt
+    ):
         dialect = mssql.dialect()
 
-        for identifier, expected_schema, expected_owner, use_stmt in [
-            ("foo", None, "foo", "use foo"),
-            ("foo.bar", "foo", "bar", "use foo"),
-            ("Foo.Bar", "Foo", "Bar", "use [Foo]"),
-            ("[Foo.Bar]", None, "Foo.Bar", "use [Foo].[Bar]"),
-            ("[Foo.Bar].[bat]", "Foo.Bar", "bat", "use [Foo].[Bar]"),
-            (
-                "[foo].]do something; select [foo",
-                "foo",
-                "do something; select foo",
-                "use foo",
-            ),
-            (
-                "something; select [foo].bar",
-                "something; select foo",
-                "bar",
-                "use [something; select foo]",
+        schema, owner = base._owner_plus_db(dialect, identifier)
+
+        eq_(owner, expected_owner)
+        eq_(schema, expected_schema)
+
+        mock_connection = mock.Mock(
+            dialect=dialect,
+            exec_driver_sql=mock.Mock(
+                return_value=mock.Mock(
+                    scalar=mock.Mock(return_value="Some Database")
+                )
             ),
-        ]:
-            schema, owner = base._owner_plus_db(dialect, identifier)
-
-            eq_(owner, expected_owner)
-            eq_(schema, expected_schema)
-
-            mock_connection = mock.Mock(
-                dialect=dialect,
-                exec_driver_sql=mock.Mock(
-                    return_value=mock.Mock(
-                        scalar=mock.Mock(return_value="Some ] Database")
-                    )
-                ),
+        )
+        mock_lambda = mock.Mock()
+        base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
+        if schema is None:
+            eq_(mock_connection.mock_calls, [])
+        else:
+            eq_(
+                mock_connection.mock_calls,
+                [
+                    mock.call.exec_driver_sql("select db_name()"),
+                    mock.call.exec_driver_sql(use_stmt),
+                    mock.call.exec_driver_sql("use [Some Database]"),
+                ],
             )
-            mock_lambda = mock.Mock()
-            base._switch_db(schema, mock_connection, mock_lambda, "x", y="bar")
-            if schema is None:
-                eq_(mock_connection.mock_calls, [])
-            else:
-                eq_(
-                    mock_connection.mock_calls,
-                    [
-                        mock.call.exec_driver_sql("select db_name()"),
-                        mock.call.exec_driver_sql(use_stmt),
-                        mock.call.exec_driver_sql("use [Some  Database]"),
-                    ],
-                )
-                eq_(
-                    mock_connection.exec_driver_sql.return_value.mock_calls,
-                    [mock.call.scalar()],
-                )
-            eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])
+            eq_(
+                mock_connection.exec_driver_sql.return_value.mock_calls,
+                [mock.call.scalar()],
+            )
+        eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")])