]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve generated reflection in sqlite
authorFederico Caselli <cfederico87@gmail.com>
Sun, 7 Jul 2024 09:56:56 +0000 (11:56 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 9 Jul 2024 15:51:54 +0000 (11:51 -0400)
Fixed reflection of computed column in SQLite to properly account
for complex expressions.

Fixes: #11582
Change-Id: I8e9fdda3e47c04b376973ee245b3175374a08f56
(cherry picked from commit e67a0b77a82667e2199e333bae0606d143fa228e)

doc/build/changelog/unreleased_14/11582.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/base.py
test/dialect/test_sqlite.py

diff --git a/doc/build/changelog/unreleased_14/11582.rst b/doc/build/changelog/unreleased_14/11582.rst
new file mode 100644 (file)
index 0000000..935af9b
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, reflection, sqlite
+    :tickets: 11582
+
+    Fixed reflection of computed column in SQLite to properly account
+    for complex expressions.
index bcf38edc7297d1c37efd927b4944f7542cfdbe73..c171136ac2b1b0e9b36f359f854dbd2443f3a3e5 100644 (file)
@@ -2100,6 +2100,14 @@ class SQLiteDialect(default.DefaultDialect):
                 tablesql = self._get_table_sql(
                     connection, table_name, schema, **kw
                 )
+                # remove create table
+                match = re.match(
+                    r"create table .*?\((.*)\)$",
+                    tablesql.strip(),
+                    re.DOTALL | re.IGNORECASE,
+                )
+                assert match, "create table not found in %s" % tablesql
+                tablesql = match.group(1).strip()
 
             columns.append(
                 self._get_column_info(
@@ -2149,7 +2157,10 @@ class SQLiteDialect(default.DefaultDialect):
         if generated:
             sqltext = ""
             if tablesql:
-                pattern = r"[^,]*\s+AS\s+\(([^,]*)\)\s*(?:virtual|stored)?"
+                pattern = (
+                    r"[^,]*\s+GENERATED\s+ALWAYS\s+AS"
+                    r"\s+\((.*)\)\s*(?:virtual|stored)?"
+                )
                 match = re.search(
                     re.escape(name) + pattern, tablesql, re.IGNORECASE
                 )
index 418bf9c6575f8bcb2d82d8c170ca7ff88bbe047d..12e607020e0c1ce07d9d0249f877e4d4cd9b932e 100644 (file)
@@ -54,6 +54,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.types import Boolean
@@ -114,7 +115,7 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults):
         )
 
     def test_cant_parse_datetime_message(self, connection):
-        for (typ, disp) in [
+        for typ, disp in [
             (Time, "time"),
             (DateTime, "datetime"),
             (Date, "date"),
@@ -992,7 +993,6 @@ class AttachedDBTest(fixtures.TestBase):
 
 
 class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
-
     """Tests SQLite-dialect specific compilation."""
 
     __dialect__ = sqlite.dialect()
@@ -1387,7 +1387,6 @@ class OnConflictDDLTest(fixtures.TestBase, AssertsCompiledSQL):
 
 
 class InsertTest(fixtures.TestBase, AssertsExecutionResults):
-
     """Tests inserts and autoincrement."""
 
     __only_on__ = "sqlite"
@@ -2385,8 +2384,8 @@ class ConstraintReflectionTest(fixtures.TestBase):
             [
                 {
                     "unique": 0,
-                    "name": u"ix_main_l_bar",
-                    "column_names": [u"bar"],
+                    "name": "ix_main_l_bar",
+                    "column_names": ["bar"],
                     "dialect_options": {},
                 }
             ],
@@ -2586,7 +2585,6 @@ class ConstraintReflectionTest(fixtures.TestBase):
 
 
 class SavepointTest(fixtures.TablesTest):
-
     """test that savepoints work when we use the correct event setup"""
 
     __only_on__ = "sqlite"
@@ -3544,3 +3542,100 @@ class OnConflictTest(AssertsCompiledSQL, fixtures.TablesTest):
             conn.scalar(sql.select(bind_targets.c.data)),
             "new updated data processed",
         )
+
+
+class ComputedReflectionTest(fixtures.TestBase):
+    __only_on__ = "sqlite"
+    __backend__ = True
+
+    @classmethod
+    def setup_test_class(cls):
+        tables = [
+            """CREATE TABLE test1 (
+                s VARCHAR,
+                x VARCHAR GENERATED ALWAYS AS (s || 'x')
+            );""",
+            """CREATE TABLE test2 (
+                s VARCHAR,
+                x VARCHAR GENERATED ALWAYS AS (s || 'x'),
+                y VARCHAR GENERATED ALWAYS AS (s || 'y')
+            );""",
+            """CREATE TABLE test3 (
+                s VARCHAR,
+                x INTEGER GENERATED ALWAYS AS (INSTR(s, ","))
+            );""",
+            """CREATE TABLE test4 (
+                s VARCHAR,
+                x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")),
+                y INTEGER GENERATED ALWAYS AS (INSTR(x, ",")));""",
+            """CREATE TABLE test5 (
+                s VARCHAR,
+                x VARCHAR GENERATED ALWAYS AS (s || 'x') STORED
+            );""",
+            """CREATE TABLE test6 (
+                s VARCHAR,
+                x VARCHAR GENERATED ALWAYS AS (s || 'x') STORED,
+                y VARCHAR GENERATED ALWAYS AS (s || 'y') STORED
+            );""",
+            """CREATE TABLE test7 (
+                s VARCHAR,
+                x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")) STORED
+            );""",
+            """CREATE TABLE test8 (
+                s VARCHAR,
+                x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")) STORED,
+                y INTEGER GENERATED ALWAYS AS (INSTR(x, ",")) STORED
+            );""",
+        ]
+
+        with testing.db.begin() as conn:
+            for ct in tables:
+                conn.exec_driver_sql(ct)
+
+    @classmethod
+    def teardown_test_class(cls):
+        with testing.db.begin() as conn:
+            for tn in cls.res:
+                conn.exec_driver_sql("DROP TABLE %s" % tn)
+
+    res = {
+        "test1": {"x": {"text": "s || 'x'", "stored": False}},
+        "test2": {
+            "x": {"text": "s || 'x'", "stored": False},
+            "y": {"text": "s || 'y'", "stored": False},
+        },
+        "test3": {"x": {"text": 'INSTR(s, ",")', "stored": False}},
+        "test4": {
+            "x": {"text": 'INSTR(s, ",")', "stored": False},
+            "y": {"text": 'INSTR(x, ",")', "stored": False},
+        },
+        "test5": {"x": {"text": "s || 'x'", "stored": True}},
+        "test6": {
+            "x": {"text": "s || 'x'", "stored": True},
+            "y": {"text": "s || 'y'", "stored": True},
+        },
+        "test7": {"x": {"text": 'INSTR(s, ",")', "stored": True}},
+        "test8": {
+            "x": {"text": 'INSTR(s, ",")', "stored": True},
+            "y": {"text": 'INSTR(x, ",")', "stored": True},
+        },
+    }
+
+    def test_reflection(self, connection):
+        meta = MetaData()
+        meta.reflect(connection)
+        eq_(len(meta.tables), len(self.res))
+        for tbl in meta.tables.values():
+            data = self.res[tbl.name]
+            seen = set()
+            for col in tbl.c:
+                if col.name not in data:
+                    is_(col.computed, None)
+                else:
+                    info = data[col.name]
+                    seen.add(col.name)
+                    msg = "%s-%s" % (tbl.name, col.name)
+                    is_true(bool(col.computed))
+                    eq_(col.computed.sqltext.text, info["text"], msg)
+                    eq_(col.computed.persisted, info["stored"], msg)
+            eq_(seen, set(data.keys()))