]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow SQL expression for ORM primary keys
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 11 Feb 2019 22:00:47 +0000 (17:00 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Feb 2019 21:55:48 +0000 (16:55 -0500)
A SQL expression can now be assigned to a primary key attribute for an ORM
flush in the same manner as ordinary attributes as described in
:ref:`flush_embedded_sql_expressions` where the expression will be evaulated
and then returned to the ORM using RETURNING, or in the case of pysqlite,
works using the cursor.lastrowid attribute.Requires either a database that
supports RETURNING (e.g. Postgresql, Oracle, SQL Server) or pysqlite.

Fixes: #3133
Fixes: #4494
Change-Id: I83da8357354de002cb04fa4a553f2a2f90c5157d

doc/build/changelog/unreleased_13/3133.rst [new file with mode: 0644]
doc/build/orm/persistence_techniques.rst
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/crud.py
test/orm/test_unitofwork.py
test/requirements.py
test/sql/test_insert.py

diff --git a/doc/build/changelog/unreleased_13/3133.rst b/doc/build/changelog/unreleased_13/3133.rst
new file mode 100644 (file)
index 0000000..c163e52
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+   :tags: feature, orm
+   :tickets: 3133
+
+   A SQL expression can now be assigned to a primary key attribute for an ORM
+   flush in the same manner as ordinary attributes as described in
+   :ref:`flush_embedded_sql_expressions` where the expression will be evaulated
+   and then returned to the ORM using RETURNING, or in the case of pysqlite,
+   works using the cursor.lastrowid attribute.Requires either a database that
+   supports RETURNING (e.g. Postgresql, Oracle, SQL Server) or pysqlite.
index a26be6b4c0ceee55e7e0b2f47279c96df658464e..0a40e77956cc3352d3838252aee5329f1b8fd1af 100644 (file)
@@ -32,6 +32,44 @@ flush/commit operation, the ``value`` attribute on ``someobject`` above is
 expired, so that when next accessed the newly generated value will be loaded
 from the database.
 
+The feature also has conditional support to work in conjunction with
+primary key columns.  A database that supports RETURNING, e.g. PostgreSQL,
+Oracle, or SQL Server, or as a special case when using SQLite with the pysqlite
+driver and a single auto-increment column, a SQL expression may be assigned
+to a primary key column as well.  This allows both the SQL expression to
+be evaluated, as well as allows any server side triggers that modify the
+primary key value on INSERT, to be successfully retrieved by the ORM as
+part of the object's primary key::
+
+
+    class Foo(Base):
+        __tablename__ = 'foo'
+        pk = Column(Integer, primary_key=True)
+        bar = Column(Integer)
+
+    e = create_engine("postgresql://scott:tiger@localhost/test", echo=True)
+    Base.metadata.create_all(e)
+
+    session = Session(e)
+
+    foo = Foo(pk=sql.select([sql.func.coalesce(sql.func.max(Foo.pk) + 1, 1)])
+    session.add(foo)
+    session.commit()
+
+On PostgreSQL, the above :class:`.Session` will emit the following INSERT:
+
+.. sourcecode:: sql
+
+    INSERT INTO foo (foopk, bar) VALUES
+    ((SELECT coalesce(max(foo.foopk) + %(max_1)s, %(coalesce_2)s) AS coalesce_1
+    FROM foo), %(bar)s) RETURNING foo.foopk
+
+.. versionadded:: 1.3
+    SQL expressions can now be passed to a primary key column during an ORM
+    flush; if the database supports RETURNING, or if pysqlite is in use, the
+    ORM will be able to retrieve the server-generated value as the value
+    of the primary key attribute.
+
 .. _session_sql_expressions:
 
 Using SQL Expressions with Sessions
index c90c8d91ed29b0d3db889c757120042216400b76..6345ee28a5ccffaf81aec4583340a051c7652a2f 100644 (file)
@@ -507,7 +507,7 @@ def _collect_insert_commands(
                 and hasattr(value, "__clause_element__")
                 or isinstance(value, sql.ClauseElement)
             ):
-                value_params[col.key] = (
+                value_params[col] = (
                     value.__clause_element__()
                     if hasattr(value, "__clause_element__")
                     else value
@@ -525,7 +525,7 @@ def _collect_insert_commands(
             for colkey in (
                 mapper._insert_cols_as_none[table]
                 .difference(params)
-                .difference(value_params)
+                .difference([c.key for c in value_params])
             ):
                 params[colkey] = None
 
@@ -932,6 +932,7 @@ def _emit_update_statements(
                         c,
                         c.context.compiled_parameters[0],
                         value_params,
+                        True,
                     )
                 rows += c.rowcount
                 check_rowcount = assert_singlerow
@@ -963,6 +964,7 @@ def _emit_update_statements(
                             c,
                             c.context.compiled_parameters[0],
                             value_params,
+                            True,
                         )
                     rows += c.rowcount
             else:
@@ -998,6 +1000,7 @@ def _emit_update_statements(
                             c,
                             c.context.compiled_parameters[0],
                             value_params,
+                            True,
                         )
 
         if check_rowcount:
@@ -1086,6 +1089,7 @@ def _emit_insert_statements(
                             c,
                             last_inserted_params,
                             value_params,
+                            False,
                         )
                     else:
                         _postfetch_bulk_save(mapper_rec, state_dict, table)
@@ -1117,14 +1121,16 @@ def _emit_insert_statements(
                     )
 
                 primary_key = result.context.inserted_primary_key
-
                 if primary_key is not None:
                     # set primary key attributes
                     for pk, col in zip(
                         primary_key, mapper._pks_by_table[table]
                     ):
                         prop = mapper_rec._columntoproperty[col]
-                        if state_dict.get(prop.key) is None:
+                        if pk is not None and (
+                            col in value_params
+                            or state_dict.get(prop.key) is None
+                        ):
                             state_dict[prop.key] = pk
                 if bookkeeping:
                     if state:
@@ -1137,6 +1143,7 @@ def _emit_insert_statements(
                             result,
                             result.context.compiled_parameters[0],
                             value_params,
+                            False,
                         )
                     else:
                         _postfetch_bulk_save(mapper_rec, state_dict, table)
@@ -1461,7 +1468,15 @@ def _postfetch_post_update(
 
 
 def _postfetch(
-    mapper, uowtransaction, table, state, dict_, result, params, value_params
+    mapper,
+    uowtransaction,
+    table,
+    state,
+    dict_,
+    result,
+    params,
+    value_params,
+    isupdate,
 ):
     """Expire attributes in need of newly persisted database state,
     after an INSERT or UPDATE statement has proceeded for that
@@ -1511,6 +1526,18 @@ def _postfetch(
             state, uowtransaction, load_evt_attrs
         )
 
+    if isupdate and value_params:
+        # explicitly suit the use case specified by
+        # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
+        # database which are set to themselves in order to do a version bump.
+        postfetch_cols.extend(
+            [
+                col
+                for col in value_params
+                if col.primary_key and col not in returning_cols
+            ]
+        )
+
     if postfetch_cols:
         state._expire_attributes(
             state.dict,
index cc72073ca23f1efc2508424c66200cce8a685d00..6c9b8ee5bcf4fc27b7e8e6ed00a2d9b28236b109 100644 (file)
@@ -409,7 +409,12 @@ def _append_param_parameter(
             compiler.returning.append(c)
             value = compiler.process(value.self_group(), **kw)
         else:
-            compiler.postfetch.append(c)
+            # postfetch specifically means, "we can SELECT the row we just
+            # inserted by primary key to get back the server generated
+            # defaults". so by definition this can't be used to get the primary
+            # key value back, because we need to have it ahead of time.
+            if not c.primary_key:
+                compiler.postfetch.append(c)
             value = compiler.process(value.self_group(), **kw)
     values.append((c, value))
 
index 6326f5f1a5c5a3eea88c92aa693c21e0361d286f..dc8a818d244799604bb6b75132f4f6aad3237001 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy import event
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import select
 from sqlalchemy import String
@@ -499,6 +500,19 @@ class ClauseAttributesTest(fixtures.MappedTest):
             Column("value", Boolean),
         )
 
+        Table(
+            "pk_t",
+            metadata,
+            Column(
+                "p_id",
+                Integer,
+                key="id",
+                autoincrement=True,
+                primary_key=True,
+            ),
+            Column("data", String(30)),
+        )
+
     @classmethod
     def setup_classes(cls):
         class User(cls.Comparable):
@@ -507,12 +521,17 @@ class ClauseAttributesTest(fixtures.MappedTest):
         class HasBoolean(cls.Comparable):
             pass
 
+        class PkDefault(cls.Comparable):
+            pass
+
     @classmethod
     def setup_mappers(cls):
         User, users_t = cls.classes.User, cls.tables.users_t
         HasBoolean, boolean_t = cls.classes.HasBoolean, cls.tables.boolean_t
+        PkDefault, pk_t = cls.classes.PkDefault, cls.tables.pk_t
         mapper(User, users_t)
         mapper(HasBoolean, boolean_t)
+        mapper(PkDefault, pk_t)
 
     def test_update(self):
         User = self.classes.User
@@ -568,6 +587,19 @@ class ClauseAttributesTest(fixtures.MappedTest):
 
         assert (u.counter == 5) is True
 
+    @testing.requires.sql_expressions_inserted_as_primary_key
+    def test_insert_pk_expression(self):
+        PkDefault = self.classes.PkDefault
+
+        pk = PkDefault(id=literal(5) + 10, data="some data")
+        session = Session()
+        session.add(pk)
+        session.flush()
+
+        eq_(pk.id, 15)
+        session.commit()
+        eq_(pk.id, 15)
+
     def test_update_special_comparator(self):
         HasBoolean = self.classes.HasBoolean
 
index c7c7206fcdbee501054afb47bb0d13008f757145..ab3f01d04a0d00ab42a8e9873ae460ad554c14cb 100644 (file)
@@ -349,6 +349,10 @@ class DefaultRequirements(SuiteRequirements):
             "postgresql", "doesn't support sequences as a server side default."
         )
 
+    @property
+    def sql_expressions_inserted_as_primary_key(self):
+        return only_if([self.returning, self.sqlite])
+
     @property
     def correlated_outer_joins(self):
         """Target must support an outer join to a subquery which
index 066e54e93c8a3f6303709ed68ff565d2bd838461..cf6715a0e6a60ab524a1af50720ff9370b339a39 100644 (file)
@@ -16,6 +16,7 @@ from sqlalchemy import table
 from sqlalchemy import text
 from sqlalchemy.dialects import mysql
 from sqlalchemy.dialects import postgresql
+from sqlalchemy.dialects import sqlite
 from sqlalchemy.engine import default
 from sqlalchemy.sql import crud
 from sqlalchemy.testing import assert_raises
@@ -1196,6 +1197,96 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             dialect=postgresql.dialect(),
         )
 
+    def test_sql_expression_pk_autoinc_lastinserted(self):
+        # test that postfetch isn't invoked for a SQL expression
+        # in a primary key column.  the DB either needs to support a lastrowid
+        # that can return it, or RETURNING.  [ticket:3133]
+        metadata = MetaData()
+        table = Table(
+            "sometable",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String),
+        )
+
+        stmt = table.insert().return_defaults().values(id=func.foobar())
+        compiled = stmt.compile(dialect=sqlite.dialect(), column_keys=["data"])
+        eq_(compiled.postfetch, [])
+        eq_(compiled.returning, [])
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO sometable (id, data) VALUES " "(foobar(), ?)",
+            checkparams={"data": "foo"},
+            params={"data": "foo"},
+            dialect=sqlite.dialect(),
+        )
+
+    def test_sql_expression_pk_autoinc_returning(self):
+        # test that return_defaults() works with a primary key where we are
+        # sending a SQL expression, and we need to get the server-calculated
+        # value back.  [ticket:3133]
+        metadata = MetaData()
+        table = Table(
+            "sometable",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String),
+        )
+
+        stmt = table.insert().return_defaults().values(id=func.foobar())
+        returning_dialect = postgresql.dialect()
+        returning_dialect.implicit_returning = True
+        compiled = stmt.compile(
+            dialect=returning_dialect, column_keys=["data"]
+        )
+        eq_(compiled.postfetch, [])
+        eq_(compiled.returning, [table.c.id])
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO sometable (id, data) VALUES "
+            "(foobar(), %(data)s) RETURNING sometable.id",
+            checkparams={"data": "foo"},
+            params={"data": "foo"},
+            dialect=returning_dialect,
+        )
+
+    def test_sql_expression_pk_noautoinc_returning(self):
+        # test that return_defaults() works with a primary key where we are
+        # sending a SQL expression, and we need to get the server-calculated
+        # value back.  [ticket:3133]
+        metadata = MetaData()
+        table = Table(
+            "sometable",
+            metadata,
+            Column(
+                "id",
+                Integer,
+                autoincrement=False,
+                primary_key=True,
+            ),
+            Column("data", String),
+        )
+
+        stmt = table.insert().return_defaults().values(id=func.foobar())
+        returning_dialect = postgresql.dialect()
+        returning_dialect.implicit_returning = True
+        compiled = stmt.compile(
+            dialect=returning_dialect, column_keys=["data"]
+        )
+        eq_(compiled.postfetch, [])
+        eq_(compiled.returning, [table.c.id])
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO sometable (id, data) VALUES "
+            "(foobar(), %(data)s) RETURNING sometable.id",
+            checkparams={"data": "foo"},
+            params={"data": "foo"},
+            dialect=returning_dialect,
+        )
+
     def test_python_fn_default(self):
         metadata = MetaData()
         table = Table(