From: Mike Bayer Date: Sun, 26 Jun 2022 14:22:39 +0000 (-0400) Subject: require at least one dialect name for variant X-Git-Tag: rel_2_0_0b1~210^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=815de6c3438ccba25b163eae2c34c5df7d82bf4d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git require at least one dialect name for variant the call doesn't make sense otherwise Fixes: #8179 Change-Id: I0e5dd584dc7090b536f9732cbfc6f3a5c8846dc5 --- diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index d8f1e92c49..46bf151eaf 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -679,16 +679,17 @@ class TypeEngine(Visitable, Generic[_T]): """ + if not dialect_names: + raise exc.ArgumentError("At least one dialect name is required") for dialect_name in dialect_names: if dialect_name in self._variant_mapping: raise exc.ArgumentError( - "Dialect '%s' is already present in " - "the mapping for this %r" % (dialect_name, self) + f"Dialect {dialect_name!r} is already present in " + f"the mapping for this {self!r}" ) new_type = self.copy() - if isinstance(type_, type): - type_ = type_() - elif type_._variant_mapping: + type_ = to_instance(type_) + if type_._variant_mapping: raise exc.ArgumentError( "can't pass a type that already has variants as a " "dialect-level type to with_variant()" diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 623688b83e..a154666bb6 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1515,10 +1515,16 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL): self.UTypeTwo = UTypeTwo self.UTypeThree = UTypeThree self.variant = self.UTypeOne().with_variant( - self.UTypeTwo(), "postgresql" + self.UTypeTwo(), "postgresql", "mssql" ) self.composite = self.variant.with_variant(self.UTypeThree(), "mysql") + def test_one_dialect_is_req(self): + with expect_raises_message( + exc.ArgumentError, "At least one dialect name is required" + ): + String().with_variant(VARCHAR()) + def test_illegal_dupe(self): v = self.UTypeOne().with_variant(self.UTypeTwo(), "postgresql") assert_raises_message( @@ -1547,6 +1553,9 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( self.variant, "UTYPETWO", dialect=dialects.postgresql.dialect() ) + self.assert_compile( + self.variant, "UTYPETWO", dialect=dialects.mssql.dialect() + ) def test_to_instance(self): self.assert_compile(