]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix sqlite regex for quoted fk, pk names
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Nov 2025 02:57:28 +0000 (22:57 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Nov 2025 02:57:28 +0000 (22:57 -0400)
Fixed issue where SQLite dialect would fail to reflect constraint names
that contained uppercase letters or other characters requiring quoting. The
regular expressions used to parse primary key, foreign key, and unique
constraint names from the ``CREATE TABLE`` statement have been updated to
properly handle both quoted and unquoted constraint names.

Fixes: #12954
Change-Id: If5c24f536795e5db867d857242013610a04638fc

lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/testing/suite/test_reflection.py

index a05d2c3602cc6ca60ce3ed4197e543945b73916c..3c7cc7d99f4962c8a4600cc8916548b28570bb33 100644 (file)
@@ -2533,9 +2533,12 @@ class SQLiteDialect(default.DefaultDialect):
         constraint_name = None
         table_data = self._get_table_sql(connection, table_name, schema=schema)
         if table_data:
-            PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY"
+            PK_PATTERN = r'CONSTRAINT +(?:"(.+?)"|(\w+)) +PRIMARY KEY'
             result = re.search(PK_PATTERN, table_data, re.I)
-            constraint_name = result.group(1) if result else None
+            if result:
+                constraint_name = result.group(1) or result.group(2)
+            else:
+                constraint_name = None
 
         cols = self.get_columns(connection, table_name, schema, **kw)
         # consider only pk columns. This also avoids sorting the cached
@@ -2635,7 +2638,7 @@ class SQLiteDialect(default.DefaultDialect):
             # so parsing the columns is really about matching it up to what
             # we already have.
             FK_PATTERN = (
-                r"(?:CONSTRAINT (\w+) +)?"
+                r'(?:CONSTRAINT +(?:"(.+?)"|(\w+)) +)?'
                 r"FOREIGN KEY *\( *(.+?) *\) +"
                 r'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\( *((?:(?:"[^"]+"|[a-z0-9_]+) *(?:, *)?)+)\) *'  # noqa: E501
                 r"((?:ON (?:DELETE|UPDATE) "
@@ -2645,6 +2648,7 @@ class SQLiteDialect(default.DefaultDialect):
             )
             for match in re.finditer(FK_PATTERN, table_data, re.I):
                 (
+                    constraint_quoted_name,
                     constraint_name,
                     constrained_columns,
                     referred_quoted_name,
@@ -2653,7 +2657,8 @@ class SQLiteDialect(default.DefaultDialect):
                     onupdatedelete,
                     deferrable,
                     initially,
-                ) = match.group(1, 2, 3, 4, 5, 6, 7, 8)
+                ) = match.group(1, 2, 3, 4, 5, 6, 7, 8, 9)
+                constraint_name = constraint_quoted_name or constraint_name
                 constrained_columns = list(
                     self._find_cols_in_sig(constrained_columns)
                 )
@@ -2748,14 +2753,17 @@ class SQLiteDialect(default.DefaultDialect):
         def parse_uqs():
             if table_data is None:
                 return
-            UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)'
+            UNIQUE_PATTERN = (
+                r'(?:CONSTRAINT +(?:"(.+?)"|(\w+)) +)?UNIQUE *\((.+?)\)'
+            )
             INLINE_UNIQUE_PATTERN = (
                 r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?)[\t ]'
                 r"+[a-z0-9_ ]+?[\t ]+UNIQUE"
             )
 
             for match in re.finditer(UNIQUE_PATTERN, table_data, re.I):
-                name, cols = match.group(1, 2)
+                quoted_name, unquoted_name, cols = match.group(1, 2, 3)
+                name = quoted_name or unquoted_name
                 yield name, list(self._find_cols_in_sig(cols))
 
             # we need to match inlines as well, as we seek to differentiate
index 7da5e0541401eccdfab85ffc5147d5cc79625caa..86427a6a68257908ab95c9cad80b0a2ef5c64967 100644 (file)
@@ -1822,6 +1822,37 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest):
             self._required_pk_keys,
         )
 
+    @testing.combinations(
+        "PK_test_table",
+        "pk_test_table",
+        "mixedCasePK",
+        "pk.with.dots",
+        argnames="pk_name",
+    )
+    @testing.requires.primary_key_constraint_reflection
+    @testing.requires.reflects_pk_names
+    def test_get_pk_constraint_quoted_name(
+        self, connection, metadata, pk_name
+    ):
+        """Test that primary key constraint names with various casing are
+        properly reflected."""
+
+        Table(
+            "test_table",
+            metadata,
+            Column("id", Integer),
+            Column("data", String(50)),
+            sa.PrimaryKeyConstraint("id", name=pk_name),
+        )
+
+        metadata.create_all(connection)
+
+        insp = inspect(connection)
+        pk_cons = insp.get_pk_constraint("test_table")
+
+        eq_(pk_cons["name"], pk_name)
+        eq_(pk_cons["constrained_columns"], ["id"])
+
     @testing.combinations(
         (False,), (True, testing.requires.schemas), argnames="use_schema"
     )
@@ -1864,6 +1895,53 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest):
         no_cst = self.tables.no_constraints.name
         eq_(insp.get_foreign_keys(no_cst, schema=schema), [])
 
+    @testing.combinations(
+        "FK_users_id",
+        "fk_users_id",
+        "mixedCaseName",
+        "fk.with.dots",
+        argnames="fk_name",
+    )
+    @testing.requires.foreign_key_constraint_reflection
+    def test_get_foreign_keys_quoted_name(self, connection, metadata, fk_name):
+        """Test that foreign key constraint names with various casing are
+        properly reflected."""
+
+        Table(
+            "users_ref",
+            metadata,
+            Column("user_id", Integer, primary_key=True),
+            test_needs_fk=True,
+        )
+
+        Table(
+            "user_orders",
+            metadata,
+            Column("order_id", Integer, primary_key=True),
+            Column("user_id", Integer),
+            sa.ForeignKeyConstraint(
+                ["user_id"],
+                ["users_ref.user_id"],
+                name=fk_name,
+            ),
+            test_needs_fk=True,
+        )
+
+        metadata.create_all(connection)
+
+        insp = inspect(connection)
+        fkeys = insp.get_foreign_keys("user_orders")
+
+        eq_(len(fkeys), 1)
+        fkey = fkeys[0]
+
+        with testing.requires.named_constraints.fail_if():
+            eq_(fkey["name"], fk_name)
+
+        eq_(fkey["referred_table"], "users_ref")
+        eq_(fkey["referred_columns"], ["user_id"])
+        eq_(fkey["constrained_columns"], ["user_id"])
+
     @testing.requires.cross_schema_fk_reflection
     @testing.requires.schemas
     def test_get_inter_schema_foreign_keys(self, connection):
@@ -1950,6 +2028,38 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest):
         is_(list(t.indexes)[0].table, t)
         eq_(list(t.indexes)[0].name, ixname)
 
+    @testing.combinations(
+        "IX_test_data",
+        "ix_test_data",
+        "mixedCaseIndex",
+        "ix.with.dots",
+        argnames="idx_name",
+    )
+    @testing.requires.index_reflection
+    def test_get_indexes_quoted_name(self, connection, metadata, idx_name):
+        """Test that index names with various casing are properly reflected."""
+
+        t = Table(
+            "test_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+        )
+        Index(idx_name, t.c.data)
+
+        metadata.create_all(connection)
+
+        insp = inspect(connection)
+        indexes = insp.get_indexes("test_table")
+
+        index_names = [idx["name"] for idx in indexes]
+        assert idx_name in index_names, f"Expected {idx_name} in {index_names}"
+
+        # Find the specific index
+        matching_idx = [idx for idx in indexes if idx["name"] == idx_name]
+        eq_(len(matching_idx), 1)
+        eq_(matching_idx[0]["column_names"], ["data"])
+
     @testing.requires.temp_table_reflection
     @testing.requires.unique_constraint_reflection
     def test_get_temp_table_unique_constraints(self, connection):
@@ -2068,6 +2178,37 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest):
         no_cst = self.tables.no_constraints.name
         eq_(insp.get_unique_constraints(no_cst, schema=schema), [])
 
+    @testing.combinations(
+        "UQ_email",
+        "uq_email",
+        "mixedCaseUQ",
+        "uq.with.dots",
+        argnames="uq_name",
+    )
+    @testing.requires.unique_constraint_reflection
+    def test_get_unique_constraints_quoted_name(
+        self, connection, metadata, uq_name
+    ):
+        """Test that unique constraint names with various casing are
+        properly reflected."""
+
+        Table(
+            "test_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("email", String(50)),
+            sa.UniqueConstraint("email", name=uq_name),
+        )
+
+        metadata.create_all(connection)
+
+        insp = inspect(connection)
+        uq_cons = insp.get_unique_constraints("test_table")
+
+        eq_(len(uq_cons), 1)
+        eq_(uq_cons[0]["name"], uq_name)
+        eq_(uq_cons[0]["column_names"], ["email"])
+
     @testing.requires.view_reflection
     @testing.combinations(
         (False,), (True, testing.requires.schemas), argnames="use_schema"