From 2afb138d310da41d17f9e3dc9fa9339b52e7a9a4 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 1 Aug 2024 15:51:00 -0400 Subject: [PATCH] escape percents for mysql enum and add suite tests Fixed issue in MySQL dialect where ENUM values that contained percent signs were not properly escaped for the driver. Fixes: #11479 Change-Id: I40d9aba619618603d3abb466f84a793d152b6788 --- doc/build/changelog/unreleased_20/11479.rst | 7 +++ lib/sqlalchemy/dialects/mysql/base.py | 2 + lib/sqlalchemy/testing/suite/test_types.py | 70 +++++++++++++++++++++ 3 files changed, 79 insertions(+) create mode 100644 doc/build/changelog/unreleased_20/11479.rst diff --git a/doc/build/changelog/unreleased_20/11479.rst b/doc/build/changelog/unreleased_20/11479.rst new file mode 100644 index 0000000000..fccaaf8026 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11479.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, mysql + :tickets: 11479 + + Fixed issue in MySQL dialect where ENUM values that contained percent signs + were not properly escaped for the driver. + diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index af1a030ced..d5db02d278 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -2380,6 +2380,8 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def _visit_enumerated_values(self, name, type_, enumerated_values): quoted_enums = [] for e in enumerated_values: + if self.dialect.identifier_preparer._double_percents: + e = e.replace("%", "%%") quoted_enums.append("'%s'" % e.replace("'", "''")) return self._extend_string( type_, {}, "%s(%s)" % (name, ",".join(quoted_enums)) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 4a7c1f199e..d4c5a2250d 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -32,6 +32,7 @@ from ... import case from ... import cast from ... import Date from ... import DateTime +from ... import Enum from ... import Float from ... import Integer from ... import Interval @@ -1918,6 +1919,74 @@ class JSONLegacyStringCastIndexTest( ) +class EnumTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __backend__ = True + + enum_values = "a", "b", "a%", "b%percent", "réveillé" + + datatype = Enum(*enum_values, name="myenum") + + @classmethod + def define_tables(cls, metadata): + Table( + "enum_table", + metadata, + Column("id", Integer, primary_key=True), + Column("enum_data", cls.datatype), + ) + + @testing.combinations(*enum_values, argnames="data") + def test_round_trip(self, data, connection): + connection.execute( + self.tables.enum_table.insert(), {"id": 1, "enum_data": data} + ) + + eq_( + connection.scalar( + select(self.tables.enum_table.c.enum_data).where( + self.tables.enum_table.c.id == 1 + ) + ), + data, + ) + + def test_round_trip_executemany(self, connection): + connection.execute( + self.tables.enum_table.insert(), + [ + {"id": 1, "enum_data": "b%percent"}, + {"id": 2, "enum_data": "réveillé"}, + {"id": 3, "enum_data": "b"}, + {"id": 4, "enum_data": "a%"}, + ], + ) + + eq_( + connection.scalars( + select(self.tables.enum_table.c.enum_data).order_by( + self.tables.enum_table.c.id + ) + ).all(), + ["b%percent", "réveillé", "b", "a%"], + ) + + @testing.requires.insert_executemany_returning + def test_round_trip_executemany_returning(self, connection): + result = connection.execute( + self.tables.enum_table.insert().returning( + self.tables.enum_table.c.enum_data + ), + [ + {"id": 1, "enum_data": "b%percent"}, + {"id": 2, "enum_data": "réveillé"}, + {"id": 3, "enum_data": "b"}, + {"id": 4, "enum_data": "a%"}, + ], + ) + + eq_(result.scalars().all(), ["b%percent", "réveillé", "b", "a%"]) + + class UuidTest(_LiteralRoundTripFixture, fixtures.TablesTest): __backend__ = True @@ -2066,6 +2135,7 @@ __all__ = ( "DateHistoricTest", "StringTest", "BooleanTest", + "EnumTest", "UuidTest", "NativeUUIDTest", ) -- 2.47.2