fix sqlite regex for quoted fk, pk names
author Mike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Nov 2025 02:57:28 +0000 (22:57 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Nov 2025 02:58:52 +0000 (22:58 -0400)
Fixed issue where SQLite dialect would fail to reflect constraint names
that contained uppercase letters or other characters requiring quoting. The
regular expressions used to parse primary key, foreign key, and unique
constraint names from the ``CREATE TABLE`` statement have been updated to
properly handle both quoted and unquoted constraint names.

Fixes: #12954
Change-Id: If5c24f536795e5db867d857242013610a04638fc
(cherry picked from commit cdaf1824316ba6fa7b52164b50cd9fd4aeb2c41f)
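For illustration only (a standalone sketch, not part of the patch; the CREATE TABLE
text is hypothetical), the previous pattern fails to match a quoted constraint name,
while the updated pattern captures it via the added quoted-name group:

    import re

    # Hypothetical DDL as it would appear in sqlite_master for a quoted,
    # mixed-case constraint name.
    table_data = (
        'CREATE TABLE t (id INTEGER, '
        'CONSTRAINT "PK_test_table" PRIMARY KEY (id))'
    )

    OLD_PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY"
    NEW_PK_PATTERN = r'CONSTRAINT +(?:"(.+?)"|(\w+)) +PRIMARY KEY'

    # Old pattern: \w+ cannot match the opening quote, so nothing is found.
    print(re.search(OLD_PK_PATTERN, table_data, re.I))  # None

    # New pattern: group 1 holds a quoted name, group 2 an unquoted one;
    # coalescing the two yields the name in either case.
    m = re.search(NEW_PK_PATTERN, table_data, re.I)
    print(m.group(1) or m.group(2))  # PK_test_table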

lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/testing/suite/test_reflection.py

lib/sqlalchemy/dialects/sqlite/base.py
index 3d786b8129f06d0002ef46e37af730a6681af0c7..dc13ed45ad9942776352e5a7eeb2dd6a7d641e0f 100644 (file)
@@ -2541,9 +2541,12 @@ class SQLiteDialect(default.DefaultDialect):
         constraint_name = None
         table_data = self._get_table_sql(connection, table_name, schema=schema)
         if table_data:
-            PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY"
+            PK_PATTERN = r'CONSTRAINT +(?:"(.+?)"|(\w+)) +PRIMARY KEY'
             result = re.search(PK_PATTERN, table_data, re.I)
-            constraint_name = result.group(1) if result else None
+            if result:
+                constraint_name = result.group(1) or result.group(2)
+            else:
+                constraint_name = None
 
         cols = self.get_columns(connection, table_name, schema, **kw)
         # consider only pk columns. This also avoids sorting the cached
@@ -2643,7 +2646,7 @@ class SQLiteDialect(default.DefaultDialect):
             # so parsing the columns is really about matching it up to what
             # we already have.
             FK_PATTERN = (
-                r"(?:CONSTRAINT (\w+) +)?"
+                r'(?:CONSTRAINT +(?:"(.+?)"|(\w+)) +)?'
                 r"FOREIGN KEY *\( *(.+?) *\) +"
                 r'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\( *((?:(?:"[^"]+"|[a-z0-9_]+) *(?:, *)?)+)\) *'  # noqa: E501
                 r"((?:ON (?:DELETE|UPDATE) "
@@ -2653,6 +2656,7 @@ class SQLiteDialect(default.DefaultDialect):
             )
             for match in re.finditer(FK_PATTERN, table_data, re.I):
                 (
+                    constraint_quoted_name,
                     constraint_name,
                     constrained_columns,
                     referred_quoted_name,
@@ -2661,7 +2665,8 @@ class SQLiteDialect(default.DefaultDialect):
                     onupdatedelete,
                     deferrable,
                     initially,
-                ) = match.group(1, 2, 3, 4, 5, 6, 7, 8)
+                ) = match.group(1, 2, 3, 4, 5, 6, 7, 8, 9)
+                constraint_name = constraint_quoted_name or constraint_name
                 constrained_columns = list(
                     self._find_cols_in_sig(constrained_columns)
                 )
@@ -2756,14 +2761,17 @@ class SQLiteDialect(default.DefaultDialect):
         def parse_uqs():
             if table_data is None:
                 return
-            UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)'
+            UNIQUE_PATTERN = (
+                r'(?:CONSTRAINT +(?:"(.+?)"|(\w+)) +)?UNIQUE *\((.+?)\)'
+            )
             INLINE_UNIQUE_PATTERN = (
                 r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?)[\t ]'
                 r"+[a-z0-9_ ]+?[\t ]+UNIQUE"
             )
 
             for match in re.finditer(UNIQUE_PATTERN, table_data, re.I):
-                name, cols = match.group(1, 2)
+                quoted_name, unquoted_name, cols = match.group(1, 2, 3)
+                name = quoted_name or unquoted_name
                 yield name, list(self._find_cols_in_sig(cols))
 
             # we need to match inlines as well, as we seek to differentiate
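A rough standalone sketch of the group coalescing used by the updated UNIQUE_PATTERN
above (the DDL strings are hypothetical, not taken from the test suite):

    import re

    UNIQUE_PATTERN = r'(?:CONSTRAINT +(?:"(.+?)"|(\w+)) +)?UNIQUE *\((.+?)\)'

    ddl_variants = [
        # quoted, mixed-case constraint name
        'CREATE TABLE t (email VARCHAR(50), CONSTRAINT "UQ_Email" UNIQUE (email))',
        # unquoted, lower-case constraint name
        "CREATE TABLE t (email VARCHAR(50), CONSTRAINT uq_email UNIQUE (email))",
    ]

    for table_data in ddl_variants:
        for match in re.finditer(UNIQUE_PATTERN, table_data, re.I):
            quoted_name, unquoted_name, cols = match.group(1, 2, 3)
            # at most one of the two name groups is populated per match
            print(quoted_name or unquoted_name, cols)
    # -> UQ_Email email
    # -> uq_email email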
lib/sqlalchemy/testing/suite/test_reflection.py
index dda34f249c2652bdb56ab59ddfe6f0ab4e29dc99..8ba588d2b5af5c4ba61e9c4d44ebb23bae18ec1a 100644 (file)
@@ -1821,6 +1821,37 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest):
             self._required_pk_keys,
         )
 
+    @testing.combinations(
+        "PK_test_table",
+        "pk_test_table",
+        "mixedCasePK",
+        "pk.with.dots",
+        argnames="pk_name",
+    )
+    @testing.requires.primary_key_constraint_reflection
+    @testing.requires.reflects_pk_names
+    def test_get_pk_constraint_quoted_name(
+        self, connection, metadata, pk_name
+    ):
+        """Test that primary key constraint names with various casing are
+        properly reflected."""
+
+        Table(
+            "test_table",
+            metadata,
+            Column("id", Integer),
+            Column("data", String(50)),
+            sa.PrimaryKeyConstraint("id", name=pk_name),
+        )
+
+        metadata.create_all(connection)
+
+        insp = inspect(connection)
+        pk_cons = insp.get_pk_constraint("test_table")
+
+        eq_(pk_cons["name"], pk_name)
+        eq_(pk_cons["constrained_columns"], ["id"])
+
     @testing.combinations(
         (False,), (True, testing.requires.schemas), argnames="use_schema"
     )
@@ -1863,6 +1894,53 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest):
         no_cst = self.tables.no_constraints.name
         eq_(insp.get_foreign_keys(no_cst, schema=schema), [])
 
+    @testing.combinations(
+        "FK_users_id",
+        "fk_users_id",
+        "mixedCaseName",
+        "fk.with.dots",
+        argnames="fk_name",
+    )
+    @testing.requires.foreign_key_constraint_reflection
+    def test_get_foreign_keys_quoted_name(self, connection, metadata, fk_name):
+        """Test that foreign key constraint names with various casing are
+        properly reflected."""
+
+        Table(
+            "users_ref",
+            metadata,
+            Column("user_id", Integer, primary_key=True),
+            test_needs_fk=True,
+        )
+
+        Table(
+            "user_orders",
+            metadata,
+            Column("order_id", Integer, primary_key=True),
+            Column("user_id", Integer),
+            sa.ForeignKeyConstraint(
+                ["user_id"],
+                ["users_ref.user_id"],
+                name=fk_name,
+            ),
+            test_needs_fk=True,
+        )
+
+        metadata.create_all(connection)
+
+        insp = inspect(connection)
+        fkeys = insp.get_foreign_keys("user_orders")
+
+        eq_(len(fkeys), 1)
+        fkey = fkeys[0]
+
+        with testing.requires.named_constraints.fail_if():
+            eq_(fkey["name"], fk_name)
+
+        eq_(fkey["referred_table"], "users_ref")
+        eq_(fkey["referred_columns"], ["user_id"])
+        eq_(fkey["constrained_columns"], ["user_id"])
+
     @testing.requires.cross_schema_fk_reflection
     @testing.requires.schemas
     def test_get_inter_schema_foreign_keys(self, connection):
@@ -1949,6 +2027,38 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest):
         is_(list(t.indexes)[0].table, t)
         eq_(list(t.indexes)[0].name, ixname)
 
+    @testing.combinations(
+        "IX_test_data",
+        "ix_test_data",
+        "mixedCaseIndex",
+        "ix.with.dots",
+        argnames="idx_name",
+    )
+    @testing.requires.index_reflection
+    def test_get_indexes_quoted_name(self, connection, metadata, idx_name):
+        """Test that index names with various casing are properly reflected."""
+
+        t = Table(
+            "test_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+        )
+        Index(idx_name, t.c.data)
+
+        metadata.create_all(connection)
+
+        insp = inspect(connection)
+        indexes = insp.get_indexes("test_table")
+
+        index_names = [idx["name"] for idx in indexes]
+        assert idx_name in index_names, f"Expected {idx_name} in {index_names}"
+
+        # Find the specific index
+        matching_idx = [idx for idx in indexes if idx["name"] == idx_name]
+        eq_(len(matching_idx), 1)
+        eq_(matching_idx[0]["column_names"], ["data"])
+
     @testing.requires.temp_table_reflection
     @testing.requires.unique_constraint_reflection
     def test_get_temp_table_unique_constraints(self, connection):
@@ -2067,6 +2177,37 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest):
         no_cst = self.tables.no_constraints.name
         eq_(insp.get_unique_constraints(no_cst, schema=schema), [])
 
+    @testing.combinations(
+        "UQ_email",
+        "uq_email",
+        "mixedCaseUQ",
+        "uq.with.dots",
+        argnames="uq_name",
+    )
+    @testing.requires.unique_constraint_reflection
+    def test_get_unique_constraints_quoted_name(
+        self, connection, metadata, uq_name
+    ):
+        """Test that unique constraint names with various casing are
+        properly reflected."""
+
+        Table(
+            "test_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("email", String(50)),
+            sa.UniqueConstraint("email", name=uq_name),
+        )
+
+        metadata.create_all(connection)
+
+        insp = inspect(connection)
+        uq_cons = insp.get_unique_constraints("test_table")
+
+        eq_(len(uq_cons), 1)
+        eq_(uq_cons[0]["name"], uq_name)
+        eq_(uq_cons[0]["column_names"], ["email"])
+
     @testing.requires.view_reflection
     @testing.combinations(
         (False,), (True, testing.requires.schemas), argnames="use_schema"