]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve generated reflection in sqlite
authorFederico Caselli <cfederico87@gmail.com>
Sun, 7 Jul 2024 09:56:56 +0000 (11:56 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Sun, 7 Jul 2024 21:25:28 +0000 (23:25 +0200)
Fixed reflection of computed column in SQLite to properly account
for complex expressions.

Fixes: #11582
Change-Id: I8e9fdda3e47c04b376973ee245b3175374a08f56

doc/build/changelog/unreleased_14/11582.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/base.py
test/dialect/test_sqlite.py

diff --git a/doc/build/changelog/unreleased_14/11582.rst b/doc/build/changelog/unreleased_14/11582.rst
new file mode 100644 (file)
index 0000000..935af9b
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, reflection, sqlite
+    :tickets: 11582
+
+    Fixed reflection of computed column in SQLite to properly account
+    for complex expressions.
index 6db8214652a9915281129ee3fc2396d05bdd4e2e..8e3f7a560e0c7af5a424b366a6321727a04a49ee 100644 (file)
@@ -2231,6 +2231,14 @@ class SQLiteDialect(default.DefaultDialect):
                 tablesql = self._get_table_sql(
                     connection, table_name, schema, **kw
                 )
+                # remove create table
+                match = re.match(
+                    r"create table .*?\((.*)\)$",
+                    tablesql.strip(),
+                    re.DOTALL | re.IGNORECASE,
+                )
+                assert match, f"create table not found in {tablesql}"
+                tablesql = match.group(1).strip()
 
             columns.append(
                 self._get_column_info(
@@ -2285,7 +2293,10 @@ class SQLiteDialect(default.DefaultDialect):
         if generated:
             sqltext = ""
             if tablesql:
-                pattern = r"[^,]*\s+AS\s+\(([^,]*)\)\s*(?:virtual|stored)?"
+                pattern = (
+                    r"[^,]*\s+GENERATED\s+ALWAYS\s+AS"
+                    r"\s+\((.*)\)\s*(?:virtual|stored)?"
+                )
                 match = re.search(
                     re.escape(name) + pattern, tablesql, re.IGNORECASE
                 )
index 1289cf9ba0d9331fd4cf809531126575fec62d82..8dedadbde9dc7d9b7c2d249f0dcb8655c6309ef1 100644 (file)
@@ -53,6 +53,7 @@ from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.types import Boolean
 from sqlalchemy.types import Date
@@ -3554,3 +3555,100 @@ class ReflectInternalSchemaTables(fixtures.TablesTest):
             eq_(res, ["sqlitetempview"])
         finally:
             connection.exec_driver_sql("DROP VIEW sqlitetempview")
+
+
+class ComputedReflectionTest(fixtures.TestBase):
+    __only_on__ = "sqlite"
+    __backend__ = True
+
+    @classmethod
+    def setup_test_class(cls):
+        tables = [
+            """CREATE TABLE test1 (
+                s VARCHAR,
+                x VARCHAR GENERATED ALWAYS AS (s || 'x')
+            );""",
+            """CREATE TABLE test2 (
+                s VARCHAR,
+                x VARCHAR GENERATED ALWAYS AS (s || 'x'),
+                y VARCHAR GENERATED ALWAYS AS (s || 'y')
+            );""",
+            """CREATE TABLE test3 (
+                s VARCHAR,
+                x INTEGER GENERATED ALWAYS AS (INSTR(s, ","))
+            );""",
+            """CREATE TABLE test4 (
+                s VARCHAR,
+                x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")),
+                y INTEGER GENERATED ALWAYS AS (INSTR(x, ",")));""",
+            """CREATE TABLE test5 (
+                s VARCHAR,
+                x VARCHAR GENERATED ALWAYS AS (s || 'x') STORED
+            );""",
+            """CREATE TABLE test6 (
+                s VARCHAR,
+                x VARCHAR GENERATED ALWAYS AS (s || 'x') STORED,
+                y VARCHAR GENERATED ALWAYS AS (s || 'y') STORED
+            );""",
+            """CREATE TABLE test7 (
+                s VARCHAR,
+                x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")) STORED
+            );""",
+            """CREATE TABLE test8 (
+                s VARCHAR,
+                x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")) STORED,
+                y INTEGER GENERATED ALWAYS AS (INSTR(x, ",")) STORED
+            );""",
+        ]
+
+        with testing.db.begin() as conn:
+            for ct in tables:
+                conn.exec_driver_sql(ct)
+
+    @classmethod
+    def teardown_test_class(cls):
+        with testing.db.begin() as conn:
+            for tn in cls.res:
+                conn.exec_driver_sql(f"DROP TABLE {tn}")
+
+    res = {
+        "test1": {"x": {"text": "s || 'x'", "stored": False}},
+        "test2": {
+            "x": {"text": "s || 'x'", "stored": False},
+            "y": {"text": "s || 'y'", "stored": False},
+        },
+        "test3": {"x": {"text": 'INSTR(s, ",")', "stored": False}},
+        "test4": {
+            "x": {"text": 'INSTR(s, ",")', "stored": False},
+            "y": {"text": 'INSTR(x, ",")', "stored": False},
+        },
+        "test5": {"x": {"text": "s || 'x'", "stored": True}},
+        "test6": {
+            "x": {"text": "s || 'x'", "stored": True},
+            "y": {"text": "s || 'y'", "stored": True},
+        },
+        "test7": {"x": {"text": 'INSTR(s, ",")', "stored": True}},
+        "test8": {
+            "x": {"text": 'INSTR(s, ",")', "stored": True},
+            "y": {"text": 'INSTR(x, ",")', "stored": True},
+        },
+    }
+
+    def test_reflection(self, connection):
+        meta = MetaData()
+        meta.reflect(connection)
+        eq_(len(meta.tables), len(self.res))
+        for tbl in meta.tables.values():
+            data = self.res[tbl.name]
+            seen = set()
+            for col in tbl.c:
+                if col.name not in data:
+                    is_(col.computed, None)
+                else:
+                    info = data[col.name]
+                    seen.add(col.name)
+                    msg = f"{tbl.name}-{col.name}"
+                    is_true(bool(col.computed))
+                    eq_(col.computed.sqltext.text, info["text"], msg)
+                    eq_(col.computed.persisted, info["stored"], msg)
+            eq_(seen, data.keys())