From: Mike Bayer Date: Mon, 11 Feb 2019 22:00:47 +0000 (-0500) Subject: Allow SQL expression for ORM primary keys X-Git-Tag: rel_1_3_0~19^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c1f310df44033d943413170de878ce95fafa387e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Allow SQL expression for ORM primary keys A SQL expression can now be assigned to a primary key attribute for an ORM flush in the same manner as ordinary attributes as described in :ref:`flush_embedded_sql_expressions` where the expression will be evaulated and then returned to the ORM using RETURNING, or in the case of pysqlite, works using the cursor.lastrowid attribute.Requires either a database that supports RETURNING (e.g. Postgresql, Oracle, SQL Server) or pysqlite. Fixes: #3133 Fixes: #4494 Change-Id: I83da8357354de002cb04fa4a553f2a2f90c5157d --- diff --git a/doc/build/changelog/unreleased_13/3133.rst b/doc/build/changelog/unreleased_13/3133.rst new file mode 100644 index 0000000000..c163e5296d --- /dev/null +++ b/doc/build/changelog/unreleased_13/3133.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: feature, orm + :tickets: 3133 + + A SQL expression can now be assigned to a primary key attribute for an ORM + flush in the same manner as ordinary attributes as described in + :ref:`flush_embedded_sql_expressions` where the expression will be evaulated + and then returned to the ORM using RETURNING, or in the case of pysqlite, + works using the cursor.lastrowid attribute.Requires either a database that + supports RETURNING (e.g. Postgresql, Oracle, SQL Server) or pysqlite. diff --git a/doc/build/orm/persistence_techniques.rst b/doc/build/orm/persistence_techniques.rst index a26be6b4c0..0a40e77956 100644 --- a/doc/build/orm/persistence_techniques.rst +++ b/doc/build/orm/persistence_techniques.rst @@ -32,6 +32,44 @@ flush/commit operation, the ``value`` attribute on ``someobject`` above is expired, so that when next accessed the newly generated value will be loaded from the database. +The feature also has conditional support to work in conjunction with +primary key columns. A database that supports RETURNING, e.g. PostgreSQL, +Oracle, or SQL Server, or as a special case when using SQLite with the pysqlite +driver and a single auto-increment column, a SQL expression may be assigned +to a primary key column as well. This allows both the SQL expression to +be evaluated, as well as allows any server side triggers that modify the +primary key value on INSERT, to be successfully retrieved by the ORM as +part of the object's primary key:: + + + class Foo(Base): + __tablename__ = 'foo' + pk = Column(Integer, primary_key=True) + bar = Column(Integer) + + e = create_engine("postgresql://scott:tiger@localhost/test", echo=True) + Base.metadata.create_all(e) + + session = Session(e) + + foo = Foo(pk=sql.select([sql.func.coalesce(sql.func.max(Foo.pk) + 1, 1)]) + session.add(foo) + session.commit() + +On PostgreSQL, the above :class:`.Session` will emit the following INSERT: + +.. sourcecode:: sql + + INSERT INTO foo (foopk, bar) VALUES + ((SELECT coalesce(max(foo.foopk) + %(max_1)s, %(coalesce_2)s) AS coalesce_1 + FROM foo), %(bar)s) RETURNING foo.foopk + +.. versionadded:: 1.3 + SQL expressions can now be passed to a primary key column during an ORM + flush; if the database supports RETURNING, or if pysqlite is in use, the + ORM will be able to retrieve the server-generated value as the value + of the primary key attribute. + .. _session_sql_expressions: Using SQL Expressions with Sessions diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index c90c8d91ed..6345ee28a5 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -507,7 +507,7 @@ def _collect_insert_commands( and hasattr(value, "__clause_element__") or isinstance(value, sql.ClauseElement) ): - value_params[col.key] = ( + value_params[col] = ( value.__clause_element__() if hasattr(value, "__clause_element__") else value @@ -525,7 +525,7 @@ def _collect_insert_commands( for colkey in ( mapper._insert_cols_as_none[table] .difference(params) - .difference(value_params) + .difference([c.key for c in value_params]) ): params[colkey] = None @@ -932,6 +932,7 @@ def _emit_update_statements( c, c.context.compiled_parameters[0], value_params, + True, ) rows += c.rowcount check_rowcount = assert_singlerow @@ -963,6 +964,7 @@ def _emit_update_statements( c, c.context.compiled_parameters[0], value_params, + True, ) rows += c.rowcount else: @@ -998,6 +1000,7 @@ def _emit_update_statements( c, c.context.compiled_parameters[0], value_params, + True, ) if check_rowcount: @@ -1086,6 +1089,7 @@ def _emit_insert_statements( c, last_inserted_params, value_params, + False, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) @@ -1117,14 +1121,16 @@ def _emit_insert_statements( ) primary_key = result.context.inserted_primary_key - if primary_key is not None: # set primary key attributes for pk, col in zip( primary_key, mapper._pks_by_table[table] ): prop = mapper_rec._columntoproperty[col] - if state_dict.get(prop.key) is None: + if pk is not None and ( + col in value_params + or state_dict.get(prop.key) is None + ): state_dict[prop.key] = pk if bookkeeping: if state: @@ -1137,6 +1143,7 @@ def _emit_insert_statements( result, result.context.compiled_parameters[0], value_params, + False, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) @@ -1461,7 +1468,15 @@ def _postfetch_post_update( def _postfetch( - mapper, uowtransaction, table, state, dict_, result, params, value_params + mapper, + uowtransaction, + table, + state, + dict_, + result, + params, + value_params, + isupdate, ): """Expire attributes in need of newly persisted database state, after an INSERT or UPDATE statement has proceeded for that @@ -1511,6 +1526,18 @@ def _postfetch( state, uowtransaction, load_evt_attrs ) + if isupdate and value_params: + # explicitly suit the use case specified by + # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING + # database which are set to themselves in order to do a version bump. + postfetch_cols.extend( + [ + col + for col in value_params + if col.primary_key and col not in returning_cols + ] + ) + if postfetch_cols: state._expire_attributes( state.dict, diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index cc72073ca2..6c9b8ee5bc 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -409,7 +409,12 @@ def _append_param_parameter( compiler.returning.append(c) value = compiler.process(value.self_group(), **kw) else: - compiler.postfetch.append(c) + # postfetch specifically means, "we can SELECT the row we just + # inserted by primary key to get back the server generated + # defaults". so by definition this can't be used to get the primary + # key value back, because we need to have it ahead of time. + if not c.primary_key: + compiler.postfetch.append(c) value = compiler.process(value.self_group(), **kw) values.append((c, value)) diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 6326f5f1a5..dc8a818d24 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -10,6 +10,7 @@ from sqlalchemy import event from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import select from sqlalchemy import String @@ -499,6 +500,19 @@ class ClauseAttributesTest(fixtures.MappedTest): Column("value", Boolean), ) + Table( + "pk_t", + metadata, + Column( + "p_id", + Integer, + key="id", + autoincrement=True, + primary_key=True, + ), + Column("data", String(30)), + ) + @classmethod def setup_classes(cls): class User(cls.Comparable): @@ -507,12 +521,17 @@ class ClauseAttributesTest(fixtures.MappedTest): class HasBoolean(cls.Comparable): pass + class PkDefault(cls.Comparable): + pass + @classmethod def setup_mappers(cls): User, users_t = cls.classes.User, cls.tables.users_t HasBoolean, boolean_t = cls.classes.HasBoolean, cls.tables.boolean_t + PkDefault, pk_t = cls.classes.PkDefault, cls.tables.pk_t mapper(User, users_t) mapper(HasBoolean, boolean_t) + mapper(PkDefault, pk_t) def test_update(self): User = self.classes.User @@ -568,6 +587,19 @@ class ClauseAttributesTest(fixtures.MappedTest): assert (u.counter == 5) is True + @testing.requires.sql_expressions_inserted_as_primary_key + def test_insert_pk_expression(self): + PkDefault = self.classes.PkDefault + + pk = PkDefault(id=literal(5) + 10, data="some data") + session = Session() + session.add(pk) + session.flush() + + eq_(pk.id, 15) + session.commit() + eq_(pk.id, 15) + def test_update_special_comparator(self): HasBoolean = self.classes.HasBoolean diff --git a/test/requirements.py b/test/requirements.py index c7c7206fcd..ab3f01d04a 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -349,6 +349,10 @@ class DefaultRequirements(SuiteRequirements): "postgresql", "doesn't support sequences as a server side default." ) + @property + def sql_expressions_inserted_as_primary_key(self): + return only_if([self.returning, self.sqlite]) + @property def correlated_outer_joins(self): """Target must support an outer join to a subquery which diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index 066e54e93c..cf6715a0e6 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -16,6 +16,7 @@ from sqlalchemy import table from sqlalchemy import text from sqlalchemy.dialects import mysql from sqlalchemy.dialects import postgresql +from sqlalchemy.dialects import sqlite from sqlalchemy.engine import default from sqlalchemy.sql import crud from sqlalchemy.testing import assert_raises @@ -1196,6 +1197,96 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): dialect=postgresql.dialect(), ) + def test_sql_expression_pk_autoinc_lastinserted(self): + # test that postfetch isn't invoked for a SQL expression + # in a primary key column. the DB either needs to support a lastrowid + # that can return it, or RETURNING. [ticket:3133] + metadata = MetaData() + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String), + ) + + stmt = table.insert().return_defaults().values(id=func.foobar()) + compiled = stmt.compile(dialect=sqlite.dialect(), column_keys=["data"]) + eq_(compiled.postfetch, []) + eq_(compiled.returning, []) + + self.assert_compile( + stmt, + "INSERT INTO sometable (id, data) VALUES " "(foobar(), ?)", + checkparams={"data": "foo"}, + params={"data": "foo"}, + dialect=sqlite.dialect(), + ) + + def test_sql_expression_pk_autoinc_returning(self): + # test that return_defaults() works with a primary key where we are + # sending a SQL expression, and we need to get the server-calculated + # value back. [ticket:3133] + metadata = MetaData() + table = Table( + "sometable", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String), + ) + + stmt = table.insert().return_defaults().values(id=func.foobar()) + returning_dialect = postgresql.dialect() + returning_dialect.implicit_returning = True + compiled = stmt.compile( + dialect=returning_dialect, column_keys=["data"] + ) + eq_(compiled.postfetch, []) + eq_(compiled.returning, [table.c.id]) + + self.assert_compile( + stmt, + "INSERT INTO sometable (id, data) VALUES " + "(foobar(), %(data)s) RETURNING sometable.id", + checkparams={"data": "foo"}, + params={"data": "foo"}, + dialect=returning_dialect, + ) + + def test_sql_expression_pk_noautoinc_returning(self): + # test that return_defaults() works with a primary key where we are + # sending a SQL expression, and we need to get the server-calculated + # value back. [ticket:3133] + metadata = MetaData() + table = Table( + "sometable", + metadata, + Column( + "id", + Integer, + autoincrement=False, + primary_key=True, + ), + Column("data", String), + ) + + stmt = table.insert().return_defaults().values(id=func.foobar()) + returning_dialect = postgresql.dialect() + returning_dialect.implicit_returning = True + compiled = stmt.compile( + dialect=returning_dialect, column_keys=["data"] + ) + eq_(compiled.postfetch, []) + eq_(compiled.returning, [table.c.id]) + + self.assert_compile( + stmt, + "INSERT INTO sometable (id, data) VALUES " + "(foobar(), %(data)s) RETURNING sometable.id", + checkparams={"data": "foo"}, + params={"data": "foo"}, + dialect=returning_dialect, + ) + def test_python_fn_default(self): metadata = MetaData() table = Table(