]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
require at least one dialect name for variant
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 26 Jun 2022 14:22:39 +0000 (10:22 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Sun, 26 Jun 2022 14:29:04 +0000 (16:29 +0200)
the call doesn't make sense otherwise

Fixes: #8179
Change-Id: I0e5dd584dc7090b536f9732cbfc6f3a5c8846dc5

lib/sqlalchemy/sql/type_api.py
test/sql/test_types.py

index d8f1e92c493423cb7aca8fb90bcb8356a69e1264..46bf151eaf8f362aa3b4320d828372496f711b9c 100644 (file)
@@ -679,16 +679,17 @@ class TypeEngine(Visitable, Generic[_T]):
 
         """
 
+        if not dialect_names:
+            raise exc.ArgumentError("At least one dialect name is required")
         for dialect_name in dialect_names:
             if dialect_name in self._variant_mapping:
                 raise exc.ArgumentError(
-                    "Dialect '%s' is already present in "
-                    "the mapping for this %r" % (dialect_name, self)
+                    f"Dialect {dialect_name!r} is already present in "
+                    f"the mapping for this {self!r}"
                 )
         new_type = self.copy()
-        if isinstance(type_, type):
-            type_ = type_()
-        elif type_._variant_mapping:
+        type_ = to_instance(type_)
+        if type_._variant_mapping:
             raise exc.ArgumentError(
                 "can't pass a type that already has variants as a "
                 "dialect-level type to with_variant()"
index 623688b83ea5fe8872e5db595cd294ef6364b7d2..a154666bb678bc6bcb9c6032e0c94b809d916bc7 100644 (file)
@@ -1515,10 +1515,16 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
         self.UTypeTwo = UTypeTwo
         self.UTypeThree = UTypeThree
         self.variant = self.UTypeOne().with_variant(
-            self.UTypeTwo(), "postgresql"
+            self.UTypeTwo(), "postgresql", "mssql"
         )
         self.composite = self.variant.with_variant(self.UTypeThree(), "mysql")
 
+    def test_one_dialect_is_req(self):
+        with expect_raises_message(
+            exc.ArgumentError, "At least one dialect name is required"
+        ):
+            String().with_variant(VARCHAR())
+
     def test_illegal_dupe(self):
         v = self.UTypeOne().with_variant(self.UTypeTwo(), "postgresql")
         assert_raises_message(
@@ -1547,6 +1553,9 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(
             self.variant, "UTYPETWO", dialect=dialects.postgresql.dialect()
         )
+        self.assert_compile(
+            self.variant, "UTYPETWO", dialect=dialects.mssql.dialect()
+        )
 
     def test_to_instance(self):
         self.assert_compile(