From: CaselIT Date: Mon, 4 Nov 2019 22:11:21 +0000 (-0500) Subject: Support for generated columns X-Git-Tag: rel_1_3_11~9^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=16eb92ea0905196c88f4db19f83cd2ca02a57c5b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support for generated columns Added DDL support for "computed columns"; these are DDL column specifications for columns that have a server-computed value, either upon SELECT (known as "virtual") or at the point of which they are INSERTed or UPDATEd (known as "stored"). Support is established for Postgresql, MySQL, Oracle SQL Server and Firebird. Thanks to Federico Caselli for lots of work on this one. ORM round trip tests included. The ORM makes use of existing FetchedValue support and no additional ORM logic is present for the basic feature. It has been observed that Oracle RETURNING does not return the new value of a computed column upon UPDATE; it returns the prior value. As this is very dangerous, a warning is emitted if a computed column is rendered into the RETURNING clause of an UPDATE statement. Fixes: #4894 Closes: #4928 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/4928 Pull-request-sha: d39c521d5ac6ebfb4fb5b53846451de79752e64c Change-Id: I2610b2999a5b1b127ed927dcdaeee98b769643ce (cherry picked from commit 602d1e6dfd538980bb8d513867b17dbc2b4b92dd) --- diff --git a/doc/build/changelog/unreleased_13/4894.rst b/doc/build/changelog/unreleased_13/4894.rst new file mode 100644 index 0000000000..ee0f6f8122 --- /dev/null +++ b/doc/build/changelog/unreleased_13/4894.rst @@ -0,0 +1,15 @@ +.. change:: + :tags: usecase, schema + :tickets: 4894 + + Added DDL support for "computed columns"; these are DDL column + specifications for columns that have a server-computed value, either upon + SELECT (known as "virtual") or at the point of which they are INSERTed or + UPDATEd (known as "stored"). Support is established for Postgresql, MySQL, + Oracle SQL Server and Firebird. Thanks to Federico Caselli for lots of work + on this one. + + .. seealso:: + + :ref:`computed_ddl` + diff --git a/doc/build/core/defaults.rst b/doc/build/core/defaults.rst index 897a215f12..23be2b24a1 100644 --- a/doc/build/core/defaults.rst +++ b/doc/build/core/defaults.rst @@ -536,9 +536,105 @@ including the default schema, if any. :ref:`oracle_returning` - in the Oracle dialect documentation +.. _computed_ddl: + +Computed (GENERATED ALWAYS AS) Columns +-------------------------------------- + +.. versionadded:: 1.3.11 + +The :class:`.Computed` construct allows a :class:`.Column` to be declared in +DDL as a "GENERATED ALWAYS AS" column, that is, one which has a value that is +computed by the database server. The construct accepts a SQL expression +typically declared textually using a string or the :func:`.text` construct, in +a similar manner as that of :class:`.CheckConstraint`. The SQL expression is +then interpreted by the database server in order to determine the value for the +column within a row. + +Example:: + + from sqlalchemy import Table, Column, MetaData, Integer, Computed + + metadata = MetaData() + + square = Table( + "square", + metadata, + Column("id", Integer, primary_key=True), + Column("side", Integer), + Column("area", Integer, Computed("side * side")), + Column("perimeter", Integer, Computed("4 * side")), + ) + +The DDL for the ``square`` table when run on a PostgreSQL 12 backend will look +like:: + + CREATE TABLE square ( + id SERIAL NOT NULL, + side INTEGER, + area INTEGER GENERATED ALWAYS AS (side * side) STORED, + perimeter INTEGER GENERATED ALWAYS AS (4 * side) STORED, + PRIMARY KEY (id) + ) + +Whether the value is persisted upon INSERT and UPDATE, or if it is calculated +on fetch, is an implementation detail of the database; the former is known as +"stored" and the latter is known as "virtual". Some database implementations +support both, but some only support one or the other. The optional +:paramref:`.Computed.persisted` flag may be specified as ``True`` or ``False`` +to indicate if the "STORED" or "VIRTUAL" keyword should be rendered in DDL, +however this will raise an error if the keyword is not supported by the target +backend; leaving it unset will use a working default for the target backend. + +The :class:`.Computed` construct is a subclass of the :class:`.FetchedValue` +object, and will set itself up as both the "server default" and "server +onupdate" generator for the target :class:`.Column`, meaning it will be treated +as a default generating column when INSERT and UPDATE statements are generated, +as well as that it will be fetched as a generating column when using the ORM. +This includes that it will be part of the RETURNING clause of the database +for databases which support RETURNING and the generated values are to be +eagerly fetched. + +.. note:: A :class:`.Column` that is defined with the :class:`.Computed` + construct may not store any value outside of that which the server applies + to it; SQLAlchemy's behavior when a value is passed for such a column + to be written in INSERT or UPDATE is currently that the value will be + ignored. + +"GENERATED ALWAYS AS" is currently known to be supported by: + +* MySQL version 5.7 and onwards + +* MariaDB 10.x series and onwards + +* PostgreSQL as of version 12 + +* Oracle - with the caveat that RETURNING does not work correctly with UPDATE + (a warning will be emitted to this effect when the UPDATE..RETURNING that + includes a computed column is rendered) + +* Microsoft SQL Server + +* Firebird + +When :class:`.Computed` is used with an unsupported backend, if the target +dialect does not support it, a :class:`.CompileError` is raised when attempting +to render the construct. Otherwise, if the dialect supports it but the +particular database server version in use does not, then a subclass of +:class:`.DBAPIError`, usually :class:`.OperationalError`, is raised when the +DDL is emitted to the database. + +.. seealso:: + + :class:`.Computed` + Default Objects API ------------------- +.. autoclass:: Computed + :members: + + .. autoclass:: ColumnDefault diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index eb2ff1fd92..ac6431b925 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -11,6 +11,7 @@ from .schema import BLANK_SCHEMA # noqa from .schema import CheckConstraint # noqa from .schema import Column # noqa from .schema import ColumnDefault # noqa +from .schema import Computed # noqa from .schema import Constraint # noqa from .schema import DDL # noqa from .schema import DefaultClause # noqa @@ -122,7 +123,7 @@ from .engine import create_engine # noqa nosort from .engine import engine_from_config # noqa nosort -__version__ = '1.3.11' +__version__ = "1.3.11" def __go(lcls): diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index a00c127d29..82afccf8a6 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -581,6 +581,17 @@ class FBDDLCompiler(sql.compiler.DDLCompiler): drop.element ) + def visit_computed_column(self, generated): + if generated.persisted is not None: + raise exc.CompileError( + "Firebird computed columns do not support a persistence " + "method setting; set the 'persisted' flag to None for " + "Firebird support." + ) + return "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): """Install Firebird specific reserved words.""" diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index e4f1d205d1..3f46816356 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1913,13 +1913,15 @@ class MSSQLStrictCompiler(MSSQLCompiler): class MSDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - colspec = ( - self.preparer.format_column(column) - + " " - + self.dialect.type_compiler.process( + colspec = self.preparer.format_column(column) + + # type is not accepted in a computed column + if column.computed is not None: + colspec += " " + self.process(column.computed) + else: + colspec += " " + self.dialect.type_compiler.process( column.type, type_expression=column ) - ) if column.nullable is not None: if ( @@ -1929,7 +1931,8 @@ class MSDDLCompiler(compiler.DDLCompiler): or column.autoincrement is True ): colspec += " NOT NULL" - else: + elif column.computed is None: + # don't specify "NULL" for computed columns colspec += " NULL" if column.table is None: @@ -2081,6 +2084,15 @@ class MSDDLCompiler(compiler.DDLCompiler): text += self.define_constraint_deferrability(constraint) return text + def visit_computed_column(self, generated): + text = "AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + # explicitly check for True|False since None means server default + if generated.persisted is True: + text += " PERSISTED" + return text + class MSIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index e02476e0e2..2112aeae8c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1501,6 +1501,9 @@ class MySQLDDLCompiler(compiler.DDLCompiler): ), ] + if column.computed is not None: + colspec.append(self.process(column.computed)) + is_timestamp = isinstance( column.type._unwrapped_dialect_impl(self.dialect), sqltypes.TIMESTAMP, diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index f7e511be2e..d0df986ece 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -454,6 +454,7 @@ columns for non-unique indexes, all but the last column for unique indexes). from itertools import groupby import re +from ... import Computed from ... import exc from ... import schema as sa_schema from ... import sql @@ -912,6 +913,16 @@ class OracleCompiler(compiler.SQLCompiler): for i, column in enumerate( expression._select_iterables(returning_cols) ): + if self.isupdate and isinstance(column.server_default, Computed): + util.warn( + "Computed columns don't work with Oracle UPDATE " + "statements that use RETURNING; the value of the column " + "*before* the UPDATE takes place is returned. It is " + "advised to not use RETURNING with an Oracle computed " + "column. Consider setting implicit_returning to False on " + "the Table object in order to avoid implicit RETURNING " + "clauses from being generated for this Table." + ) if column.type._has_column_expression: col_expr = column.type.column_expression(column) else: @@ -1151,6 +1162,19 @@ class OracleDDLCompiler(compiler.DDLCompiler): return "".join(table_opts) + def visit_computed_column(self, generated): + text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + if generated.persisted is True: + raise exc.CompileError( + "Oracle computed columns do not support 'stored' persistence; " + "set the 'persisted' flag to None or False for Oracle support." + ) + elif generated.persisted is False: + text += " VIRTUAL" + return text + class OracleIdentifierPreparer(compiler.IdentifierPreparer): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 5520389423..af61b5105b 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1873,6 +1873,9 @@ class PGDDLCompiler(compiler.DDLCompiler): if default is not None: colspec += " DEFAULT " + default + if column.computed is not None: + colspec += " " + self.process(column.computed) + if not column.nullable: colspec += " NOT NULL" return colspec @@ -2043,6 +2046,18 @@ class PGDDLCompiler(compiler.DDLCompiler): return "".join(table_opts) + def visit_computed_column(self, generated): + if generated.persisted is False: + raise exc.CompileError( + "PostrgreSQL computed columns do not support 'virtual' " + "persistence; set the 'persisted' flag to None or True for " + "PostgreSQL support." + ) + + return "GENERATED ALWAYS AS (%s) STORED" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_TSVECTOR(self, type_, **kw): diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 05efc42359..3fb340dadf 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1032,6 +1032,9 @@ class SQLiteCompiler(compiler.SQLCompiler): class SQLiteDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): + if column.computed is not None: + raise exc.CompileError("SQLite does not support computed columns") + coltype = self.dialect.type_compiler.process( column.type, type_expression=column ) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 6ddd12e60a..cb4db213b4 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -41,6 +41,7 @@ from .sql.schema import Column # noqa from .sql.schema import ColumnCollectionConstraint # noqa from .sql.schema import ColumnCollectionMixin # noqa from .sql.schema import ColumnDefault # noqa +from .sql.schema import Computed # noqa from .sql.schema import Constraint # noqa from .sql.schema import DefaultClause # noqa from .sql.schema import DefaultGenerator # noqa diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 810a4d9286..7a063ab1b5 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3108,6 +3108,9 @@ class DDLCompiler(Compiled): if default is not None: colspec += " DEFAULT " + default + if column.computed is not None: + colspec += " " + self.process(column.computed) + if not column.nullable: colspec += " NOT NULL" return colspec @@ -3249,6 +3252,16 @@ class DDLCompiler(Compiled): text += " MATCH %s" % constraint.match return text + def visit_computed_column(self, generated): + text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + if generated.persisted is True: + text += " STORED" + elif generated.persisted is False: + text += " VIRTUAL" + return text + class GenericTypeCompiler(TypeCompiler): def visit_FLOAT(self, type_, **kw): diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index e3c3d96b96..c1a31e226b 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -1076,9 +1076,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): :class:`.SchemaItem` derived constructs which will be applied as options to the column. These include instances of :class:`.Constraint`, :class:`.ForeignKey`, :class:`.ColumnDefault`, - and :class:`.Sequence`. In some cases an equivalent keyword - argument is available such as ``server_default``, ``default`` - and ``unique``. + :class:`.Sequence`, :class:`.Computed`. In some cases an + equivalent keyword argument is available such as ``server_default``, + ``default`` and ``unique``. :param autoincrement: Set up "auto increment" semantics for an integer primary key column. The default value is the string ``"auto"`` @@ -1344,6 +1344,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): self.constraints = set() self.foreign_keys = set() self.comment = kwargs.pop("comment", None) + self.computed = None # check if this Column is proxying another column if "_proxies" in kwargs: @@ -1550,6 +1551,12 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): c.copy(**kw) for c in self.constraints if not c._type_bound ] + [c.copy(**kw) for c in self.foreign_keys if not c.constraint] + server_default = self.server_default + server_onupdate = self.server_onupdate + if isinstance(server_default, Computed): + server_default = server_onupdate = None + args.append(self.server_default.copy(**kw)) + type_ = self.type if isinstance(type_, SchemaEventTarget): type_ = type_.copy(**kw) @@ -1566,9 +1573,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): index=self.index, autoincrement=self.autoincrement, default=self.default, - server_default=self.server_default, + server_default=server_default, onupdate=self.onupdate, - server_onupdate=self.server_onupdate, + server_onupdate=server_onupdate, doc=self.doc, comment=self.comment, *args @@ -4454,3 +4461,89 @@ class _SchemaTranslateMap(object): _default_schema_map = _SchemaTranslateMap(None) _schema_getter = _SchemaTranslateMap._schema_getter + + +class Computed(FetchedValue, SchemaItem): + """Defines a generated column, i.e. "GENERATED ALWAYS AS" syntax. + + The :class:`.Computed` construct is an inline construct added to the + argument list of a :class:`.Column` object:: + + from sqlalchemy import Computed + + Table('square', meta, + Column('side', Float, nullable=False), + Column('area', Float, Computed('side * side')) + ) + + See the linked documentation below for complete details. + + .. versionadded:: 1.3.11 + + .. seealso:: + + :ref:`computed_ddl` + + """ + + __visit_name__ = "computed_column" + + @_document_text_coercion( + "sqltext", ":class:`.Computed`", ":paramref:`.Computed.sqltext`" + ) + def __init__(self, sqltext, persisted=None): + """Construct a GENERATED ALWAYS AS DDL construct to accompany a + :class:`.Column`. + + :param sqltext: + A string containing the column generation expression, which will be + used verbatim, or a SQL expression construct, such as a :func:`.text` + object. If given as a string, the object is converted to a + :func:`.text` object. + + :param persisted: + Optional, controls how this column should be persisted by the + database. Possible values are: + + * None, the default, it will use the default persistence defined + by the database. + * True, will render ``GENERATED ALWAYS AS ... STORED``, or the + equivalent for the target database if supported + * False, will render ``GENERATED ALWAYS AS ... VIRTUAL``, or the + equivalent for the target database if supported. + + Specifying ``True`` or ``False`` may raise an error when the DDL + is emitted to the target database if the databse does not support + that persistence option. Leaving this parameter at its default + of ``None`` is guaranteed to succeed for all databases that support + ``GENERATED ALWAYS AS``. + + """ + self.sqltext = _literal_as_text(sqltext, allow_coercion_to_text=True) + self.persisted = persisted + self.column = None + + def _set_parent(self, parent): + if not isinstance( + parent.server_default, (type(None), Computed) + ) or not isinstance(parent.server_onupdate, (type(None), Computed)): + raise exc.ArgumentError( + "A generated column cannot specify a server_default or a " + "server_onupdate argument" + ) + self.column = parent + parent.computed = self + self.column.server_onupdate = self + self.column.server_default = self + + def _as_for_update(self, for_update): + return self + + def copy(self, target_table=None, **kw): + if target_table is not None: + sqltext = _copy_expression(self.sqltext, self.table, target_table) + else: + sqltext = self.sqltext + g = Computed(sqltext, persisted=self.persisted) + + return self._schema_item_copy(g) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 87bbc6a0f2..8262142ec5 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -44,7 +44,7 @@ def combinations(*comb, **kw): well as if it is included in the tokens used to create the id of the parameter set. - If omitted, the argment combinations are passed to parametrize as is. If + If omitted, the argument combinations are passed to parametrize as is. If passed, each argument combination is turned into a pytest.param() object, mapping the elements of the argument tuple to produce an id based on a character value in the same position within the string template using the @@ -59,9 +59,12 @@ def combinations(*comb, **kw): r - the given argument should be passed and it should be added to the id by calling repr() - s- the given argument should be passed and it should be added to the + s - the given argument should be passed and it should be added to the id by calling str() + a - (argument) the given argument should be passed and it should not + be used to generated the id + e.g.:: @testing.combinations( diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 266aa33272..e47f6829f6 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1021,3 +1021,7 @@ class SuiteRequirements(Requirements): return True except ImportError: return False + + @property + def computed_columns(self): + return exclusions.closed() diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index d265851419..1253ca81cd 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -6,6 +6,7 @@ from ..schema import Column from ..schema import Table from ... import bindparam from ... import case +from ... import Computed from ... import false from ... import func from ... import Integer @@ -14,6 +15,7 @@ from ... import null from ... import select from ... import String from ... import testing +from ... import text from ... import true from ... import tuple_ from ... import union @@ -656,3 +658,47 @@ class LikeFunctionsTest(fixtures.TablesTest): col = self.tables.some_table.c.data self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) self._test(col.contains("b#cd", autoescape=True, escape="#"), {7}) + + +class ComputedColumnTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("computed_columns",) + + @classmethod + def define_tables(cls, metadata): + Table( + "square", + metadata, + Column("id", Integer, primary_key=True), + Column("side", Integer), + Column("area", Integer, Computed("side * side")), + Column("perimeter", Integer, Computed("4 * side")), + ) + + @classmethod + def insert_data(cls): + with config.db.begin() as conn: + conn.execute( + cls.tables.square.insert(), + [{"id": 1, "side": 10}, {"id": 10, "side": 42}], + ) + + def test_select_all(self): + with config.db.connect() as conn: + res = conn.execute( + select([text("*")]) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(1, 10, 100, 40), (10, 42, 1764, 168)]) + + def test_select_columns(self): + with config.db.connect() as conn: + res = conn.execute( + select( + [self.tables.square.c.area, self.tables.square.c.perimeter] + ) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(100, 40), (1764, 168)]) diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 4fba61dfe2..eb7ce0ac3d 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -1,5 +1,6 @@ # -*- encoding: utf-8 from sqlalchemy import Column +from sqlalchemy import Computed from sqlalchemy import delete from sqlalchemy import extract from sqlalchemy import func @@ -1120,7 +1121,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): idx = Index("test_idx_data_1", tbl.c.data, mssql_where=tbl.c.data > 1) self.assert_compile( schema.CreateIndex(idx), - "CREATE INDEX test_idx_data_1 ON test (data) WHERE data > 1" + "CREATE INDEX test_idx_data_1 ON test (data) WHERE data > 1", ) def test_index_ordering(self): @@ -1190,6 +1191,27 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT TRY_CAST (t1.id AS INTEGER) AS anon_1 FROM t1", ) + @testing.combinations( + ("no_persisted", "", "ignore"), + ("persisted_none", "", None), + ("persisted_true", " PERSISTED", True), + ("persisted_false", "", False), + id_="iaa", + ) + def test_column_computed(self, text, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE t (x INTEGER NULL, y AS (x + 2)%s)" % text, + ) + class SchemaTest(fixtures.TestBase): def setup(self): diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 301562d1c6..d59c0549f1 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -8,6 +8,7 @@ from sqlalchemy import CHAR from sqlalchemy import CheckConstraint from sqlalchemy import CLOB from sqlalchemy import Column +from sqlalchemy import Computed from sqlalchemy import DATE from sqlalchemy import Date from sqlalchemy import DATETIME @@ -386,6 +387,28 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_compile(sql.delete(a1), "DELETE FROM t1 AS a1") + @testing.combinations( + ("no_persisted", "", "ignore"), + ("persisted_none", "", None), + ("persisted_true", " STORED", True), + ("persisted_false", " VIRTUAL", False), + id_="iaa", + ) + def test_column_computed(self, text, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE t (x INTEGER, y INTEGER GENERATED " + "ALWAYS AS (x + 2)%s)" % text, + ) + class SQLTest(fixtures.TestBase, AssertsCompiledSQL): diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index 981a5f9c73..ab7f471980 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -3,6 +3,7 @@ from sqlalchemy import and_ from sqlalchemy import bindparam +from sqlalchemy import Computed from sqlalchemy import exc from sqlalchemy import except_ from sqlalchemy import ForeignKey @@ -18,6 +19,7 @@ from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import sql from sqlalchemy import String +from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import type_coerce from sqlalchemy import TypeDecorator @@ -908,6 +910,40 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "t1.c2, t1.c3 INTO :ret_0, :ret_1", ) + def test_returning_insert_computed(self): + m = MetaData() + t1 = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, Computed("foo + 42")), + ) + + self.assert_compile( + t1.insert().values(id=1, foo=5).returning(t1.c.bar), + "INSERT INTO t1 (id, foo) VALUES (:id, :foo) " + "RETURNING t1.bar INTO :ret_0", + ) + + def test_returning_update_computed_warning(self): + m = MetaData() + t1 = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, Computed("foo + 42")), + ) + + with testing.expect_warnings( + "Computed columns don't work with Oracle UPDATE" + ): + self.assert_compile( + t1.update().values(id=1, foo=5).returning(t1.c.bar), + "UPDATE t1 SET id=:id, foo=:foo RETURNING t1.bar INTO :ret_0", + ) + def test_compound(self): t1 = table("t1", column("c1"), column("c2"), column("c3")) t2 = table("t2", column("c1"), column("c2"), column("c3")) @@ -1006,6 +1042,42 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "CREATE BITMAP INDEX idx3 ON testtbl (data)", ) + @testing.combinations( + ("no_persisted", "", "ignore"), + ("persisted_none", "", None), + ("persisted_false", " VIRTUAL", False), + id_="iaa", + ) + def test_column_computed(self, text, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE t (x INTEGER, y INTEGER GENERATED " + "ALWAYS AS (x + 2)%s)" % text, + ) + + def test_column_computed_persisted_true(self): + m = MetaData() + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", persisted=True)), + ) + assert_raises_message( + exc.CompileError, + r".*Oracle computed columns do not support 'stored' ", + schema.CreateTable(t).compile, + dialect=oracle.dialect(), + ) + class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): def test_basic(self): diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index 5966731408..e226ca7fbb 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -3,6 +3,7 @@ import re from sqlalchemy import bindparam +from sqlalchemy import Computed from sqlalchemy import create_engine from sqlalchemy import exc from sqlalchemy import Float @@ -258,6 +259,72 @@ class EncodingErrorsTest(fixtures.TestBase): ) +class ComputedReturningTest(fixtures.TablesTest): + __only_on__ = "oracle" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "test", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, Computed("foo + 42")), + ) + + Table( + "test_no_returning", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, Computed("foo + 42")), + implicit_returning=False, + ) + + def test_computed_insert(self): + test = self.tables.test + with testing.db.connect() as conn: + result = conn.execute( + test.insert().return_defaults(), {"id": 1, "foo": 5} + ) + + eq_(result.returned_defaults, (47,)) + + eq_(conn.scalar(select([test.c.bar])), 47) + + def test_computed_update_warning(self): + test = self.tables.test + with testing.db.connect() as conn: + conn.execute(test.insert(), {"id": 1, "foo": 5}) + + with testing.expect_warnings( + "Computed columns don't work with Oracle UPDATE" + ): + result = conn.execute( + test.update().values(foo=10).return_defaults() + ) + + # returns the *old* value + eq_(result.returned_defaults, (47,)) + + eq_(conn.scalar(select([test.c.bar])), 52) + + def test_computed_update_no_warning(self): + test = self.tables.test_no_returning + with testing.db.connect() as conn: + conn.execute(test.insert(), {"id": 1, "foo": 5}) + + result = conn.execute( + test.update().values(foo=10).return_defaults() + ) + + # no returning + eq_(result.returned_defaults, None) + + eq_(conn.scalar(select([test.c.bar])), 52) + + class OutParamTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "oracle+cx_oracle" __backend__ = True diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 83e3ee3fd2..4c4c43281e 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -3,6 +3,7 @@ from sqlalchemy import and_ from sqlalchemy import cast from sqlalchemy import Column +from sqlalchemy import Computed from sqlalchemy import delete from sqlalchemy import Enum from sqlalchemy import exc @@ -1541,6 +1542,42 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): q, "DELETE FROM t1 AS a1 USING t2 WHERE a1.c1 = t2.c1" ) + @testing.combinations( + ("no_persisted", " STORED", "ignore"), + ("persisted_none", " STORED", None), + ("persisted_true", " STORED", True), + id_="iaa", + ) + def test_column_computed(self, text, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE t (x INTEGER, y INTEGER GENERATED " + "ALWAYS AS (x + 2)%s)" % text, + ) + + def test_column_computed_persisted_false(self): + m = MetaData() + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", persisted=False)), + ) + assert_raises_message( + exc.CompileError, + "PostrgreSQL computed columns do not support 'virtual'", + schema.CreateTable(t).compile, + dialect=postgresql.dialect(), + ) + class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = postgresql.dialect() diff --git a/test/dialect/test_firebird.py b/test/dialect/test_firebird.py index f1ce321e3c..ad146bc77e 100644 --- a/test/dialect/test_firebird.py +++ b/test/dialect/test_firebird.py @@ -1,4 +1,5 @@ from sqlalchemy import Column +from sqlalchemy import Computed from sqlalchemy import exc from sqlalchemy import Float from sqlalchemy import func @@ -438,6 +439,42 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(column("_somecol"), '"_somecol"') self.assert_compile(column("$somecol"), '"$somecol"') + @testing.combinations( + ("no_persisted", "ignore"), ("persisted_none", None), id_="ia" + ) + def test_column_computed(self, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + schema.CreateTable(t), + "CREATE TABLE t (x INTEGER, y INTEGER GENERATED " + "ALWAYS AS (x + 2))", + ) + + @testing.combinations( + ("persisted_true", True), ("persisted_false", False), id_="ia" + ) + def test_column_computed_raises(self, persisted): + m = MetaData() + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", persisted=persisted)), + ) + assert_raises_message( + exc.CompileError, + "Firebird computed columns do not support a persistence method", + schema.CreateTable(t).compile, + dialect=firebird.dialect(), + ) + class TypesTest(fixtures.TestBase): __only_on__ = "firebird" diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 8694a45b81..931631308c 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -10,6 +10,7 @@ from sqlalchemy import bindparam from sqlalchemy import CheckConstraint from sqlalchemy import Column from sqlalchemy import column +from sqlalchemy import Computed from sqlalchemy import create_engine from sqlalchemy import DefaultClause from sqlalchemy import event @@ -745,6 +746,29 @@ class DialectTest(fixtures.TestBase, AssertsExecutionResults): url = make_url(url) eq_(d.create_connect_args(url), expected) + @testing.combinations( + ("no_persisted", "ignore"), + ("persisted_none", None), + ("persisted_true", True), + ("persisted_false", False), + id_="ia", + ) + def test_column_computed(self, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + assert_raises_message( + exc.CompileError, + "SQLite does not support computed columns", + schema.CreateTable(t).compile, + dialect=sqlite.dialect(), + ) + class AttachedDBTest(fixtures.TestBase): __only_on__ = "sqlite" diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py index 5fbe091be0..9f37dbf4da 100644 --- a/test/orm/test_defaults.py +++ b/test/orm/test_defaults.py @@ -1,11 +1,16 @@ import sqlalchemy as sa +from sqlalchemy import Computed from sqlalchemy import event from sqlalchemy import Integer from sqlalchemy import String +from sqlalchemy import testing from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper +from sqlalchemy.orm import Session from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.assertsql import assert_engine +from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -170,3 +175,173 @@ class ExcludedDefaultsTest(fixtures.MappedTest): sess.add(f1) sess.flush() eq_(dt.select().execute().fetchall(), [(1, "hello")]) + + +class ComputedDefaultsOnUpdateTest(fixtures.MappedTest): + """test that computed columns are recognized as server + oninsert/onupdate defaults.""" + + __backend__ = True + __requires__ = ("computed_columns",) + + @classmethod + def define_tables(cls, metadata): + Table( + "test", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, Computed("foo + 42")), + ) + + @classmethod + def setup_classes(cls): + class Thing(cls.Basic): + pass + + class ThingNoEager(cls.Basic): + pass + + @classmethod + def setup_mappers(cls): + Thing = cls.classes.Thing + + mapper(Thing, cls.tables.test, eager_defaults=True) + + ThingNoEager = cls.classes.ThingNoEager + mapper(ThingNoEager, cls.tables.test, eager_defaults=False) + + @testing.combinations(("eager", True), ("noneager", False), id_="ia") + def test_insert_computed(self, eager): + if eager: + Thing = self.classes.Thing + else: + Thing = self.classes.ThingNoEager + + s = Session() + + t1, t2 = (Thing(id=1, foo=5), Thing(id=2, foo=10)) + + s.add_all([t1, t2]) + + with assert_engine(testing.db) as asserter: + s.flush() + eq_(t1.bar, 5 + 42) + eq_(t2.bar, 10 + 42) + + if eager and testing.db.dialect.implicit_returning: + asserter.assert_( + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (%(id)s, %(foo)s) " + "RETURNING test.bar", + [{"foo": 5, "id": 1}], + dialect="postgresql", + ), + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (%(id)s, %(foo)s) " + "RETURNING test.bar", + [{"foo": 10, "id": 2}], + dialect="postgresql", + ), + ) + else: + asserter.assert_( + CompiledSQL( + "INSERT INTO test (id, foo) VALUES (:id, :foo)", + [{"foo": 5, "id": 1}, {"foo": 10, "id": 2}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 1}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 2}], + ), + ) + + @testing.requires.computed_columns_on_update_returning + def test_update_computed_eager(self): + self._test_update_computed(True) + + def test_update_computed_noneager(self): + self._test_update_computed(False) + + def _test_update_computed(self, eager): + if eager: + Thing = self.classes.Thing + else: + Thing = self.classes.ThingNoEager + + s = Session() + + t1, t2 = (Thing(id=1, foo=1), Thing(id=2, foo=2)) + + s.add_all([t1, t2]) + s.flush() + + t1.foo = 5 + t2.foo = 6 + + with assert_engine(testing.db) as asserter: + s.flush() + eq_(t1.bar, 5 + 42) + eq_(t2.bar, 6 + 42) + + if eager and testing.db.dialect.implicit_returning: + asserter.assert_( + CompiledSQL( + "UPDATE test SET foo=%(foo)s " + "WHERE test.id = %(test_id)s " + "RETURNING test.bar", + [{"foo": 5, "test_id": 1}], + dialect="postgresql", + ), + CompiledSQL( + "UPDATE test SET foo=%(foo)s " + "WHERE test.id = %(test_id)s " + "RETURNING test.bar", + [{"foo": 6, "test_id": 2}], + dialect="postgresql", + ), + ) + elif eager: + asserter.assert_( + CompiledSQL( + "UPDATE test SET foo=:foo WHERE test.id = :test_id", + [{"foo": 5, "test_id": 1}], + ), + CompiledSQL( + "UPDATE test SET foo=:foo WHERE test.id = :test_id", + [{"foo": 6, "test_id": 2}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 1}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 2}], + ), + ) + else: + asserter.assert_( + CompiledSQL( + "UPDATE test SET foo=:foo WHERE test.id = :test_id", + [{"foo": 5, "test_id": 1}, {"foo": 6, "test_id": 2}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 1}], + ), + CompiledSQL( + "SELECT test.bar AS test_bar FROM test " + "WHERE test.id = :param_1", + [{"param_1": 2}], + ), + ) diff --git a/test/requirements.py b/test/requirements.py index ca0432ecdf..1e7c9be7e1 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -358,6 +358,10 @@ class DefaultRequirements(SuiteRequirements): def sql_expressions_inserted_as_primary_key(self): return only_if([self.returning, self.sqlite]) + @property + def computed_columns_on_update_returning(self): + return self.computed_columns + skip_if("oracle") + @property def correlated_outer_joins(self): """Target must support an outer join to a subquery which @@ -760,8 +764,9 @@ class DefaultRequirements(SuiteRequirements): @property def nullsordering(self): """Target backends that support nulls ordering.""" - return fails_on_everything_except("postgresql", "oracle", "firebird", - "sqlite >= 3.30.0") + return fails_on_everything_except( + "postgresql", "oracle", "firebird", "sqlite >= 3.30.0" + ) @property def reflects_pk_names(self): @@ -1402,3 +1407,7 @@ class DefaultRequirements(SuiteRequirements): lambda config: against(config, "oracle+cx_oracle") and config.db.dialect.cx_oracle_ver < (6,) ) + + @property + def computed_columns(self): + return skip_if(["postgresql < 12", "sqlite", "mysql < 5.7"]) diff --git a/test/sql/test_computed.py b/test/sql/test_computed.py new file mode 100644 index 0000000000..2999c621cb --- /dev/null +++ b/test/sql/test_computed.py @@ -0,0 +1,80 @@ +# coding: utf-8 +from sqlalchemy import Column +from sqlalchemy import Computed +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import Table +from sqlalchemy.exc import ArgumentError +from sqlalchemy.schema import CreateTable +from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import combinations +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_not_ + + +class DDLComputedTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = "default" + + @combinations( + ("no_persisted", "", "ignore"), + ("persisted_none", "", None), + ("persisted_true", " STORED", True), + ("persisted_false", " VIRTUAL", False), + id_="iaa", + ) + def test_column_computed(self, text, persisted): + m = MetaData() + kwargs = {"persisted": persisted} if persisted != "ignore" else {} + t = Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2", **kwargs)), + ) + self.assert_compile( + CreateTable(t), + "CREATE TABLE t (x INTEGER, y INTEGER GENERATED " + "ALWAYS AS (x + 2)%s)" % text, + ) + + def test_server_default_onupdate(self): + text = ( + "A generated column cannot specify a server_default or a " + "server_onupdate argument" + ) + + def fn(**kwargs): + m = MetaData() + Table( + "t", + m, + Column("x", Integer), + Column("y", Integer, Computed("x + 2"), **kwargs), + ) + + assert_raises_message(ArgumentError, text, fn, server_default="42") + assert_raises_message(ArgumentError, text, fn, server_onupdate="42") + + def test_tometadata(self): + comp1 = Computed("x + 2") + m = MetaData() + t = Table("t", m, Column("x", Integer), Column("y", Integer, comp1)) + is_(comp1.column, t.c.y) + is_(t.c.y.server_onupdate, comp1) + is_(t.c.y.server_default, comp1) + + m2 = MetaData() + t2 = t.tometadata(m2) + comp2 = t2.c.y.server_default + + is_not_(comp1, comp2) + + is_(comp1.column, t.c.y) + is_(t.c.y.server_onupdate, comp1) + is_(t.c.y.server_default, comp1) + + is_(comp2.column, t2.c.y) + is_(t2.c.y.server_onupdate, comp2) + is_(t2.c.y.server_default, comp2) diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 2e8d4deebc..11ed9d8de5 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -652,6 +652,8 @@ class MetaDataTest(fixtures.TestBase, ComparesTables): class ToMetaDataTest(fixtures.TestBase, ComparesTables): @testing.requires.check_constraints def test_copy(self): + # TODO: modernize this test + from sqlalchemy.testing.schema import Table meta = MetaData()