]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- :meth:`.Insert.from_select` now includes Python and SQL-expression
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Oct 2014 21:15:19 +0000 (17:15 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Oct 2014 21:15:19 +0000 (17:15 -0400)
defaults if otherwise unspecified; the limitation where non-
server column defaults aren't included in an INSERT FROM
SELECT is now lifted and these expressions are rendered as
constants into the SELECT statement.

doc/build/changelog/changelog_10.rst
doc/build/changelog/migration_10.rst
doc/build/core/defaults.rst
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/testing/suite/test_insert.py
test/sql/test_defaults.py
test/sql/test_insert.py

index 5b0362c440e12bdbf44c9f2e5876613486fd641f..3d471f1929e2923a8203340ee35a9d08814068bb 100644 (file)
     series as well.  For changes that are specific to 1.0 with an emphasis
     on compatibility concerns, see :doc:`/changelog/migration_10`.
 
+    .. change::
+        :tags: feature, sql
+
+        :meth:`.Insert.from_select` now includes Python and SQL-expression
+        defaults if otherwise unspecified; the limitation where non-
+        server column defaults aren't included in an INSERT FROM
+        SELECT is now lifted and these expressions are rendered as
+        constants into the SELECT statement.
+
+        .. seealso::
+
+            :ref:`feature_insert_from_select_defaults`
+
     .. change::
         :tags: bug, orm
         :tickets: 3222
index a3b0c03086ee109f005cc54b8163bd132f0ccca1..951e3960306a03d9d75eba2bc7a617517d03ab5b 100644 (file)
@@ -187,6 +187,36 @@ than the integer value.
 
 .. _change_2051:
 
+.. _feature_insert_from_select_defaults:
+
+INSERT FROM SELECT now includes Python and SQL-expression defaults
+-------------------------------------------------------------------
+
+:meth:`.Insert.from_select` now includes Python and SQL-expression defaults if
+otherwise unspecified; the limitation where non-server column defaults
+aren't included in an INSERT FROM SELECT is now lifted and these
+expressions are rendered as constants into the SELECT statement::
+
+    from sqlalchemy import Table, Column, MetaData, Integer, select, func
+
+    m = MetaData()
+
+    t = Table(
+        't', m,
+        Column('x', Integer),
+        Column('y', Integer, default=func.somefunction()))
+
+    stmt = select([t.c.x])
+    print t.insert().from_select(['x'], stmt)
+
+Will render::
+
+    INSERT INTO t (x, y) SELECT t.x, somefunction() AS somefunction_1
+    FROM t
+
+The feature can be disabled using
+:paramref:`.Insert.from_select.include_defaults`.
+
 New Postgresql Table options
 -----------------------------
 
index 166273c186e84a96cd311308f4f47e5c7558c78a..1d55cd6c62a0f31d78b94556dd94180ee097dbc8 100644 (file)
@@ -1,6 +1,8 @@
+.. module:: sqlalchemy.schema
+
 .. _metadata_defaults_toplevel:
+
 .. _metadata_defaults:
-.. module:: sqlalchemy.schema
 
 Column Insert/Update Defaults
 ==============================
index 86f00d94421a4973339439b7fa3a3eb6b69e9c8a..a6c30b7dc95158027d9e778a0b2daed9fc93e6b2 100644 (file)
@@ -1793,7 +1793,7 @@ class SQLCompiler(Compiled):
                 text += " " + returning_clause
 
         if insert_stmt.select is not None:
-            text += " %s" % self.process(insert_stmt.select, **kw)
+            text += " %s" % self.process(self._insert_from_select, **kw)
         elif not crud_params and supports_default_values:
             text += " DEFAULT VALUES"
         elif insert_stmt._has_multi_parameters:
index 1c1f661d2b1718775cdd1f82db2cab39999b407b..831d05be1ef1fa238317d2fb63ac3904bdc92650 100644 (file)
@@ -89,18 +89,15 @@ def _get_crud_params(compiler, stmt, **kw):
             _col_bind_name, _getattr_col_key, values, kw)
 
     if compiler.isinsert and stmt.select_names:
-        # for an insert from select, we can only use names that
-        # are given, so only select for those names.
-        cols = (stmt.table.c[_column_as_key(name)]
-                for name in stmt.select_names)
+        _scan_insert_from_select_cols(
+            compiler, stmt, parameters,
+            _getattr_col_key, _column_as_key,
+            _col_bind_name, check_columns, values, kw)
     else:
-        # iterate through all table columns to maintain
-        # ordering, even for those cols that aren't included
-        cols = stmt.table.columns
-
-    _scan_cols(
-        compiler, stmt, cols, parameters,
-        _getattr_col_key, _col_bind_name, check_columns, values, kw)
+        _scan_cols(
+            compiler, stmt, parameters,
+            _getattr_col_key, _column_as_key,
+            _col_bind_name, check_columns, values, kw)
 
     if parameters and stmt_parameters:
         check = set(parameters).intersection(
@@ -118,13 +115,17 @@ def _get_crud_params(compiler, stmt, **kw):
     return values
 
 
-def _create_bind_param(compiler, col, value, required=False, name=None):
+def _create_bind_param(
+        compiler, col, value, process=True, required=False, name=None):
     if name is None:
         name = col.key
     bindparam = elements.BindParameter(name, value,
                                        type_=col.type, required=required)
     bindparam._is_crud = True
-    return bindparam._compiler_dispatch(compiler)
+    if process:
+        bindparam = bindparam._compiler_dispatch(compiler)
+    return bindparam
+
 
 def _key_getters_for_crud_column(compiler):
     if compiler.isupdate and compiler.statement._extra_froms:
@@ -162,14 +163,52 @@ def _key_getters_for_crud_column(compiler):
     return _column_as_key, _getattr_col_key, _col_bind_name
 
 
+def _scan_insert_from_select_cols(
+    compiler, stmt, parameters, _getattr_col_key,
+        _column_as_key, _col_bind_name, check_columns, values, kw):
+
+    need_pks, implicit_returning, \
+        implicit_return_defaults, postfetch_lastrowid = \
+        _get_returning_modifiers(compiler, stmt)
+
+    cols = [stmt.table.c[_column_as_key(name)]
+            for name in stmt.select_names]
+
+    compiler._insert_from_select = stmt.select
+
+    add_select_cols = []
+    if stmt.include_insert_from_select_defaults:
+        col_set = set(cols)
+        for col in stmt.table.columns:
+            if col not in col_set and col.default:
+                cols.append(col)
+
+    for c in cols:
+        col_key = _getattr_col_key(c)
+        if col_key in parameters and col_key not in check_columns:
+            parameters.pop(col_key)
+            values.append((c, None))
+        else:
+            _append_param_insert_select_hasdefault(
+                compiler, stmt, c, add_select_cols, kw)
+
+    if add_select_cols:
+        values.extend(add_select_cols)
+        compiler._insert_from_select = compiler._insert_from_select._generate()
+        compiler._insert_from_select._raw_columns += tuple(
+            expr for col, expr in add_select_cols)
+
+
 def _scan_cols(
-    compiler, stmt, cols, parameters, _getattr_col_key,
-        _col_bind_name, check_columns, values, kw):
+    compiler, stmt, parameters, _getattr_col_key,
+        _column_as_key, _col_bind_name, check_columns, values, kw):
 
     need_pks, implicit_returning, \
         implicit_return_defaults, postfetch_lastrowid = \
         _get_returning_modifiers(compiler, stmt)
 
+    cols = stmt.table.columns
+
     for c in cols:
         col_key = _getattr_col_key(c)
         if col_key in parameters and col_key not in check_columns:
@@ -196,7 +235,8 @@ def _scan_cols(
             elif c.default is not None:
 
                 _append_param_insert_hasdefault(
-                    compiler, stmt, c, implicit_return_defaults, values, kw)
+                    compiler, stmt, c, implicit_return_defaults,
+                    values, kw)
 
             elif c.server_default is not None:
                 if implicit_return_defaults and \
@@ -299,10 +339,8 @@ def _append_param_insert_hasdefault(
             elif not c.primary_key:
                 compiler.postfetch.append(c)
     elif c.default.is_clause_element:
-        values.append(
-            (c, compiler.process(
-                c.default.arg.self_group(), **kw))
-        )
+        proc = compiler.process(c.default.arg.self_group(), **kw)
+        values.append((c, proc))
 
         if implicit_return_defaults and \
                 c in implicit_return_defaults:
@@ -317,6 +355,25 @@ def _append_param_insert_hasdefault(
         compiler.prefetch.append(c)
 
 
+def _append_param_insert_select_hasdefault(
+        compiler, stmt, c, values, kw):
+
+    if c.default.is_sequence:
+        if compiler.dialect.supports_sequences and \
+            (not c.default.optional or
+             not compiler.dialect.sequences_optional):
+            proc = c.default
+            values.append((c, proc))
+    elif c.default.is_clause_element:
+        proc = c.default.arg.self_group()
+        values.append((c, proc))
+    else:
+        values.append(
+            (c, _create_bind_param(compiler, c, None, process=False))
+        )
+        compiler.prefetch.append(c)
+
+
 def _append_param_update(
         compiler, stmt, c, implicit_return_defaults, values, kw):
 
index 1934d0776347929a2f4b8aa80b1341eaa00a6c36..9f2ce7ce39d994cb7c3841f4c977c45661cafbe8 100644 (file)
@@ -475,6 +475,7 @@ class Insert(ValuesBase):
         ValuesBase.__init__(self, table, values, prefixes)
         self._bind = bind
         self.select = self.select_names = None
+        self.include_insert_from_select_defaults = False
         self.inline = inline
         self._returning = returning
         self._validate_dialect_kwargs(dialect_kw)
@@ -487,7 +488,7 @@ class Insert(ValuesBase):
             return ()
 
     @_generative
-    def from_select(self, names, select):
+    def from_select(self, names, select, include_defaults=True):
         """Return a new :class:`.Insert` construct which represents
         an ``INSERT...FROM SELECT`` statement.
 
@@ -506,6 +507,21 @@ class Insert(ValuesBase):
          is not checked before passing along to the database, the database
          would normally raise an exception if these column lists don't
          correspond.
+        :param include_defaults: if True, non-server default values and
+         SQL expressions as specified on :class:`.Column` objects
+         (as documented in :ref:`metadata_defaults_toplevel`) not
+         otherwise specified in the list of names will be rendered
+         into the INSERT and SELECT statements, so that these values are also
+         included in the data to be inserted.
+
+         .. note:: A Python-side default that uses a Python callable function
+            will only be invoked **once** for the whole statement, and **not
+            per row**.
+
+         .. versionadded:: 1.0.0 - :meth:`.Insert.from_select` now renders
+            Python-side and SQL expression column defaults into the
+            SELECT statement for columns otherwise not included in the
+            list of column names.
 
         .. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT
            implies that the :paramref:`.insert.inline` flag is set to
@@ -514,13 +530,6 @@ class Insert(ValuesBase):
            deals with an arbitrary number of rows, so the
            :attr:`.ResultProxy.inserted_primary_key` accessor does not apply.
 
-        .. note::
-
-           A SELECT..INSERT construct in SQL has no VALUES clause.  Therefore
-           :class:`.Column` objects which utilize Python-side defaults
-           (e.g. as described at :ref:`metadata_defaults_toplevel`)
-           will **not** take effect when using :meth:`.Insert.from_select`.
-
         .. versionadded:: 0.8.3
 
         """
@@ -533,6 +542,7 @@ class Insert(ValuesBase):
 
         self.select_names = names
         self.inline = True
+        self.include_insert_from_select_defaults = include_defaults
         self.select = _interpret_as_select(select)
 
     def _copy_internals(self, clone=_clone, **kw):
index 92d3d93e5ed1aac5a191c2215b9d40dd3e4b83c5..c197145c7be13e852a5f5ee7fae621a61e0016d1 100644 (file)
@@ -4,7 +4,7 @@ from .. import exclusions
 from ..assertions import eq_
 from .. import engines
 
-from sqlalchemy import Integer, String, select, util
+from sqlalchemy import Integer, String, select, literal_column
 
 from ..schema import Table, Column
 
@@ -90,6 +90,13 @@ class InsertBehaviorTest(fixtures.TablesTest):
               Column('id', Integer, primary_key=True, autoincrement=False),
               Column('data', String(50))
               )
+        Table('includes_defaults', metadata,
+              Column('id', Integer, primary_key=True,
+                     test_needs_autoincrement=True),
+              Column('data', String(50)),
+              Column('x', Integer, default=5),
+              Column('y', Integer,
+                     default=literal_column("2", type_=Integer) + 2))
 
     def test_autoclose_on_insert(self):
         if requirements.returning.enabled:
@@ -158,6 +165,34 @@ class InsertBehaviorTest(fixtures.TablesTest):
                 ("data3", ), ("data3", )]
         )
 
+    @requirements.insert_from_select
+    def test_insert_from_select_with_defaults(self):
+        table = self.tables.includes_defaults
+        config.db.execute(
+            table.insert(),
+            [
+                dict(id=1, data="data1"),
+                dict(id=2, data="data2"),
+                dict(id=3, data="data3"),
+            ]
+        )
+
+        config.db.execute(
+            table.insert(inline=True).
+            from_select(("id", "data",),
+                        select([table.c.id + 5, table.c.data]).
+                        where(table.c.data.in_(["data2", "data3"]))
+                        ),
+        )
+
+        eq_(
+            config.db.execute(
+                select([table]).order_by(table.c.data)
+            ).fetchall(),
+            [(1, 'data1', 5, 4), (2, 'data2', 5, 4),
+                (7, 'data2', 5, 4), (3, 'data3', 5, 4), (8, 'data3', 5, 4)]
+        )
+
 
 class ReturningTest(fixtures.TablesTest):
     run_create_tables = 'each'
index abce600df429472d1bc4100962cff6ee442e2037..10e557b76cf6f185b740be9118d4d673ce8df622 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy.dialects import sqlite
 from sqlalchemy.testing import fixtures
 from sqlalchemy.util import u, b
 from sqlalchemy import util
+import itertools
 
 t = f = f2 = ts = currenttime = metadata = default_generator = None
 
@@ -1278,3 +1279,67 @@ class UnicodeDefaultsTest(fixtures.TestBase):
             "foobar", Unicode(32),
             default=default
         )
+
+
+class InsertFromSelectTest(fixtures.TestBase):
+    __backend__ = True
+
+    def _fixture(self):
+        data = Table(
+            'data', self.metadata,
+            Column('x', Integer),
+            Column('y', Integer)
+        )
+        data.create()
+        testing.db.execute(data.insert(), {'x': 2, 'y': 5}, {'x': 7, 'y': 12})
+        return data
+
+    @testing.provide_metadata
+    def test_insert_from_select_override_defaults(self):
+        data = self._fixture()
+
+        table = Table('sometable', self.metadata,
+                      Column('x', Integer),
+                      Column('foo', Integer, default=12),
+                      Column('y', Integer))
+
+        table.create()
+
+        sel = select([data.c.x, data.c.y])
+
+        ins = table.insert().\
+            from_select(["x", "y"], sel)
+        testing.db.execute(ins)
+
+        eq_(
+            testing.db.execute(table.select().order_by(table.c.x)).fetchall(),
+            [(2, 12, 5), (7, 12, 12)]
+        )
+
+    @testing.provide_metadata
+    def test_insert_from_select_fn_defaults(self):
+        data = self._fixture()
+
+        counter = itertools.count(1)
+
+        def foo(ctx):
+            return next(counter)
+
+        table = Table('sometable', self.metadata,
+                      Column('x', Integer),
+                      Column('foo', Integer, default=foo),
+                      Column('y', Integer))
+
+        table.create()
+
+        sel = select([data.c.x, data.c.y])
+
+        ins = table.insert().\
+            from_select(["x", "y"], sel)
+        testing.db.execute(ins)
+
+        # counter is only called once!
+        eq_(
+            testing.db.execute(table.select().order_by(table.c.x)).fetchall(),
+            [(2, 1, 5), (7, 1, 12)]
+        )
index 232c5758bca2e61e6cb859f8f38e6700b3682ae8..bd4eaa3e29b4b5b7bfbc4f2abf88a0a45bb7b66c 100644 (file)
@@ -183,7 +183,7 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             checkparams={"name_1": "foo"}
         )
 
-    def test_insert_from_select_select_no_defaults(self):
+    def test_insert_from_select_no_defaults(self):
         metadata = MetaData()
         table = Table('sometable', metadata,
                       Column('id', Integer, primary_key=True),
@@ -191,7 +191,7 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
         table1 = self.tables.mytable
         sel = select([table1.c.myid]).where(table1.c.name == 'foo')
         ins = table.insert().\
-            from_select(["id"], sel)
+            from_select(["id"], sel, include_defaults=False)
         self.assert_compile(
             ins,
             "INSERT INTO sometable (id) SELECT mytable.myid "
@@ -199,6 +199,84 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             checkparams={"name_1": "foo"}
         )
 
+    def test_insert_from_select_with_sql_defaults(self):
+        metadata = MetaData()
+        table = Table('sometable', metadata,
+                      Column('id', Integer, primary_key=True),
+                      Column('foo', Integer, default=func.foobar()))
+        table1 = self.tables.mytable
+        sel = select([table1.c.myid]).where(table1.c.name == 'foo')
+        ins = table.insert().\
+            from_select(["id"], sel)
+        self.assert_compile(
+            ins,
+            "INSERT INTO sometable (id, foo) SELECT "
+            "mytable.myid, foobar() AS foobar_1 "
+            "FROM mytable WHERE mytable.name = :name_1",
+            checkparams={"name_1": "foo"}
+        )
+
+    def test_insert_from_select_with_python_defaults(self):
+        metadata = MetaData()
+        table = Table('sometable', metadata,
+                      Column('id', Integer, primary_key=True),
+                      Column('foo', Integer, default=12))
+        table1 = self.tables.mytable
+        sel = select([table1.c.myid]).where(table1.c.name == 'foo')
+        ins = table.insert().\
+            from_select(["id"], sel)
+        self.assert_compile(
+            ins,
+            "INSERT INTO sometable (id, foo) SELECT "
+            "mytable.myid, :foo AS anon_1 "
+            "FROM mytable WHERE mytable.name = :name_1",
+            # value filled in at execution time
+            checkparams={"name_1": "foo", "foo": None}
+        )
+
+    def test_insert_from_select_override_defaults(self):
+        metadata = MetaData()
+        table = Table('sometable', metadata,
+                      Column('id', Integer, primary_key=True),
+                      Column('foo', Integer, default=12))
+        table1 = self.tables.mytable
+        sel = select(
+            [table1.c.myid, table1.c.myid.label('q')]).where(
+            table1.c.name == 'foo')
+        ins = table.insert().\
+            from_select(["id", "foo"], sel)
+        self.assert_compile(
+            ins,
+            "INSERT INTO sometable (id, foo) SELECT "
+            "mytable.myid, mytable.myid AS q "
+            "FROM mytable WHERE mytable.name = :name_1",
+            checkparams={"name_1": "foo"}
+        )
+
+    def test_insert_from_select_fn_defaults(self):
+        metadata = MetaData()
+
+        def foo(ctx):
+            return 12
+
+        table = Table('sometable', metadata,
+                      Column('id', Integer, primary_key=True),
+                      Column('foo', Integer, default=foo))
+        table1 = self.tables.mytable
+        sel = select(
+            [table1.c.myid]).where(
+            table1.c.name == 'foo')
+        ins = table.insert().\
+            from_select(["id"], sel)
+        self.assert_compile(
+            ins,
+            "INSERT INTO sometable (id, foo) SELECT "
+            "mytable.myid, :foo AS anon_1 "
+            "FROM mytable WHERE mytable.name = :name_1",
+            # value filled in at execution time
+            checkparams={"name_1": "foo", "foo": None}
+        )
+
     def test_insert_mix_select_values_exception(self):
         table1 = self.tables.mytable
         sel = select([table1.c.myid, table1.c.name]).where(