From: Federico Caselli Date: Sun, 7 Jul 2024 09:56:56 +0000 (+0200) Subject: Improve generated reflection in sqlite X-Git-Tag: rel_1_4_53~5 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=fdfb3a2842c3084e791bfb5d6e2e4369b8f7d8f1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve generated reflection in sqlite Fixed reflection of computed column in SQLite to properly account for complex expressions. Fixes: #11582 Change-Id: I8e9fdda3e47c04b376973ee245b3175374a08f56 (cherry picked from commit e67a0b77a82667e2199e333bae0606d143fa228e) --- diff --git a/doc/build/changelog/unreleased_14/11582.rst b/doc/build/changelog/unreleased_14/11582.rst new file mode 100644 index 0000000000..935af9b244 --- /dev/null +++ b/doc/build/changelog/unreleased_14/11582.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, reflection, sqlite + :tickets: 11582 + + Fixed reflection of computed column in SQLite to properly account + for complex expressions. diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index bcf38edc72..c171136ac2 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -2100,6 +2100,14 @@ class SQLiteDialect(default.DefaultDialect): tablesql = self._get_table_sql( connection, table_name, schema, **kw ) + # remove create table + match = re.match( + r"create table .*?\((.*)\)$", + tablesql.strip(), + re.DOTALL | re.IGNORECASE, + ) + assert match, "create table not found in %s" % tablesql + tablesql = match.group(1).strip() columns.append( self._get_column_info( @@ -2149,7 +2157,10 @@ class SQLiteDialect(default.DefaultDialect): if generated: sqltext = "" if tablesql: - pattern = r"[^,]*\s+AS\s+\(([^,]*)\)\s*(?:virtual|stored)?" + pattern = ( + r"[^,]*\s+GENERATED\s+ALWAYS\s+AS" + r"\s+\((.*)\)\s*(?:virtual|stored)?" + ) match = re.search( re.escape(name) + pattern, tablesql, re.IGNORECASE ) diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 418bf9c657..12e607020e 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -54,6 +54,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.types import Boolean @@ -114,7 +115,7 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): ) def test_cant_parse_datetime_message(self, connection): - for (typ, disp) in [ + for typ, disp in [ (Time, "time"), (DateTime, "datetime"), (Date, "date"), @@ -992,7 +993,6 @@ class AttachedDBTest(fixtures.TestBase): class SQLTest(fixtures.TestBase, AssertsCompiledSQL): - """Tests SQLite-dialect specific compilation.""" __dialect__ = sqlite.dialect() @@ -1387,7 +1387,6 @@ class OnConflictDDLTest(fixtures.TestBase, AssertsCompiledSQL): class InsertTest(fixtures.TestBase, AssertsExecutionResults): - """Tests inserts and autoincrement.""" __only_on__ = "sqlite" @@ -2385,8 +2384,8 @@ class ConstraintReflectionTest(fixtures.TestBase): [ { "unique": 0, - "name": u"ix_main_l_bar", - "column_names": [u"bar"], + "name": "ix_main_l_bar", + "column_names": ["bar"], "dialect_options": {}, } ], @@ -2586,7 +2585,6 @@ class ConstraintReflectionTest(fixtures.TestBase): class SavepointTest(fixtures.TablesTest): - """test that savepoints work when we use the correct event setup""" __only_on__ = "sqlite" @@ -3544,3 +3542,100 @@ class OnConflictTest(AssertsCompiledSQL, fixtures.TablesTest): conn.scalar(sql.select(bind_targets.c.data)), "new updated data processed", ) + + +class ComputedReflectionTest(fixtures.TestBase): + __only_on__ = "sqlite" + __backend__ = True + + @classmethod + def setup_test_class(cls): + tables = [ + """CREATE TABLE test1 ( + s VARCHAR, + x VARCHAR GENERATED ALWAYS AS (s || 'x') + );""", + """CREATE TABLE test2 ( + s VARCHAR, + x VARCHAR GENERATED ALWAYS AS (s || 'x'), + y VARCHAR GENERATED ALWAYS AS (s || 'y') + );""", + """CREATE TABLE test3 ( + s VARCHAR, + x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")) + );""", + """CREATE TABLE test4 ( + s VARCHAR, + x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")), + y INTEGER GENERATED ALWAYS AS (INSTR(x, ",")));""", + """CREATE TABLE test5 ( + s VARCHAR, + x VARCHAR GENERATED ALWAYS AS (s || 'x') STORED + );""", + """CREATE TABLE test6 ( + s VARCHAR, + x VARCHAR GENERATED ALWAYS AS (s || 'x') STORED, + y VARCHAR GENERATED ALWAYS AS (s || 'y') STORED + );""", + """CREATE TABLE test7 ( + s VARCHAR, + x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")) STORED + );""", + """CREATE TABLE test8 ( + s VARCHAR, + x INTEGER GENERATED ALWAYS AS (INSTR(s, ",")) STORED, + y INTEGER GENERATED ALWAYS AS (INSTR(x, ",")) STORED + );""", + ] + + with testing.db.begin() as conn: + for ct in tables: + conn.exec_driver_sql(ct) + + @classmethod + def teardown_test_class(cls): + with testing.db.begin() as conn: + for tn in cls.res: + conn.exec_driver_sql("DROP TABLE %s" % tn) + + res = { + "test1": {"x": {"text": "s || 'x'", "stored": False}}, + "test2": { + "x": {"text": "s || 'x'", "stored": False}, + "y": {"text": "s || 'y'", "stored": False}, + }, + "test3": {"x": {"text": 'INSTR(s, ",")', "stored": False}}, + "test4": { + "x": {"text": 'INSTR(s, ",")', "stored": False}, + "y": {"text": 'INSTR(x, ",")', "stored": False}, + }, + "test5": {"x": {"text": "s || 'x'", "stored": True}}, + "test6": { + "x": {"text": "s || 'x'", "stored": True}, + "y": {"text": "s || 'y'", "stored": True}, + }, + "test7": {"x": {"text": 'INSTR(s, ",")', "stored": True}}, + "test8": { + "x": {"text": 'INSTR(s, ",")', "stored": True}, + "y": {"text": 'INSTR(x, ",")', "stored": True}, + }, + } + + def test_reflection(self, connection): + meta = MetaData() + meta.reflect(connection) + eq_(len(meta.tables), len(self.res)) + for tbl in meta.tables.values(): + data = self.res[tbl.name] + seen = set() + for col in tbl.c: + if col.name not in data: + is_(col.computed, None) + else: + info = data[col.name] + seen.add(col.name) + msg = "%s-%s" % (tbl.name, col.name) + is_true(bool(col.computed)) + eq_(col.computed.sqltext.text, info["text"], msg) + eq_(col.computed.persisted, info["stored"], msg) + eq_(seen, set(data.keys()))