]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
escape percents for mysql enum and add suite tests
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Aug 2024 19:51:00 +0000 (15:51 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Aug 2024 19:51:00 +0000 (15:51 -0400)
Fixed issue in MySQL dialect where ENUM values that contained percent signs
were not properly escaped for the driver.

Fixes: #11479
Change-Id: I40d9aba619618603d3abb466f84a793d152b6788

doc/build/changelog/unreleased_20/11479.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/testing/suite/test_types.py

diff --git a/doc/build/changelog/unreleased_20/11479.rst b/doc/build/changelog/unreleased_20/11479.rst
new file mode 100644 (file)
index 0000000..fccaaf8
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, mysql
+    :tickets: 11479
+
+    Fixed issue in MySQL dialect where ENUM values that contained percent signs
+    were not properly escaped for the driver.
+
index af1a030ced11d615d4ddcc42261250289ad2c811..d5db02d2781bd371d63a5d51025cd97a287aee3b 100644 (file)
@@ -2380,6 +2380,8 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
     def _visit_enumerated_values(self, name, type_, enumerated_values):
         quoted_enums = []
         for e in enumerated_values:
+            if self.dialect.identifier_preparer._double_percents:
+                e = e.replace("%", "%%")
             quoted_enums.append("'%s'" % e.replace("'", "''"))
         return self._extend_string(
             type_, {}, "%s(%s)" % (name, ",".join(quoted_enums))
index 4a7c1f199e134c1d89221a0cd44bbe00078158bc..d4c5a2250dc95804ae06df04d597a1f63893b252 100644 (file)
@@ -32,6 +32,7 @@ from ... import case
 from ... import cast
 from ... import Date
 from ... import DateTime
+from ... import Enum
 from ... import Float
 from ... import Integer
 from ... import Interval
@@ -1918,6 +1919,74 @@ class JSONLegacyStringCastIndexTest(
         )
 
 
+class EnumTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+    __backend__ = True
+
+    enum_values = "a", "b", "a%", "b%percent", "réveillé"
+
+    datatype = Enum(*enum_values, name="myenum")
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "enum_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("enum_data", cls.datatype),
+        )
+
+    @testing.combinations(*enum_values, argnames="data")
+    def test_round_trip(self, data, connection):
+        connection.execute(
+            self.tables.enum_table.insert(), {"id": 1, "enum_data": data}
+        )
+
+        eq_(
+            connection.scalar(
+                select(self.tables.enum_table.c.enum_data).where(
+                    self.tables.enum_table.c.id == 1
+                )
+            ),
+            data,
+        )
+
+    def test_round_trip_executemany(self, connection):
+        connection.execute(
+            self.tables.enum_table.insert(),
+            [
+                {"id": 1, "enum_data": "b%percent"},
+                {"id": 2, "enum_data": "réveillé"},
+                {"id": 3, "enum_data": "b"},
+                {"id": 4, "enum_data": "a%"},
+            ],
+        )
+
+        eq_(
+            connection.scalars(
+                select(self.tables.enum_table.c.enum_data).order_by(
+                    self.tables.enum_table.c.id
+                )
+            ).all(),
+            ["b%percent", "réveillé", "b", "a%"],
+        )
+
+    @testing.requires.insert_executemany_returning
+    def test_round_trip_executemany_returning(self, connection):
+        result = connection.execute(
+            self.tables.enum_table.insert().returning(
+                self.tables.enum_table.c.enum_data
+            ),
+            [
+                {"id": 1, "enum_data": "b%percent"},
+                {"id": 2, "enum_data": "réveillé"},
+                {"id": 3, "enum_data": "b"},
+                {"id": 4, "enum_data": "a%"},
+            ],
+        )
+
+        eq_(result.scalars().all(), ["b%percent", "réveillé", "b", "a%"])
+
+
 class UuidTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     __backend__ = True
 
@@ -2066,6 +2135,7 @@ __all__ = (
     "DateHistoricTest",
     "StringTest",
     "BooleanTest",
+    "EnumTest",
     "UuidTest",
     "NativeUUIDTest",
 )