]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added new generic function "next_value()", accepts
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 Mar 2011 16:49:28 +0000 (12:49 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 Mar 2011 16:49:28 +0000 (12:49 -0400)
a Sequence object as its argument and renders the
appropriate "next value" generation string on the
target platform, if supported.  Also provides
".next_value()" method on Sequence itself.
[ticket:2085]
- added tests for all the conditions described
in [ticket:2085]
- postgresql dialect will exec/compile a Sequence
that has "optional=True".  the optional flag is now only
checked specifically in the context of a Table primary key
evaulation.
- func.next_value() or other SQL expression can
be embedded directly into an insert() construct,
and if implicit or explicit "returning" is used
in conjunction with a primary key column,
the newly generated value will be present in
result.inserted_primary_key. [ticket:2084]

CHANGES
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py
test/sql/test_defaults.py

diff --git a/CHANGES b/CHANGES
index c7e872340bddd6ab1ab45c9860dc5a8be82cef34..f4d32a604426bb72039e5da895d5a348a3575c6d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -84,6 +84,20 @@ CHANGES
     (<eventname>, <fn>), which are applied to the Table 
     before the reflection process begins.
 
+  - Added new generic function "next_value()", accepts
+    a Sequence object as its argument and renders the
+    appropriate "next value" generation string on the
+    target platform, if supported.  Also provides
+    ".next_value()" method on Sequence itself.   
+    [ticket:2085]
+
+  - func.next_value() or other SQL expression can
+    be embedded directly into an insert() construct,
+    and if implicit or explicit "returning" is used
+    in conjunction with a primary key column, 
+    the newly generated value will be present in
+    result.inserted_primary_key. [ticket:2084]
+
 - engine
   - Fixed AssertionPool regression bug.  [ticket:2097]
 
index d3c1bc1391688efe1338b22b9ff03c3a65e74d05..72411d735759266170ad4bad133c3ed92bc1113e 100644 (file)
@@ -392,7 +392,8 @@ class OracleCompiler(compiler.SQLCompiler):
             return ""
 
     def default_from(self):
-        """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
+        """Called when a ``SELECT`` statement has no froms, 
+        and no ``FROM`` clause is to be appended.
 
         The Oracle compiler tacks a "FROM DUAL" to the statement.
         """
index 8bceeef6571c8c649d5da1ca6b57d851fc5013d9..cc2f461f98517fadf0eabb4b38a783be81379fe0 100644 (file)
@@ -461,10 +461,7 @@ class PGCompiler(compiler.SQLCompiler):
         return value
 
     def visit_sequence(self, seq):
-        if seq.optional:
-            return None
-        else:
-            return "nextval('%s')" % self.preparer.format_sequence(seq)
+        return "nextval('%s')" % self.preparer.format_sequence(seq)
 
     def limit_clause(self, select):
         text = ""
@@ -717,23 +714,19 @@ class DropEnumType(schema._CreateDropBase):
 
 class PGExecutionContext(default.DefaultExecutionContext):
     def fire_sequence(self, seq, type_):
-        if not seq.optional:
-            return self._execute_scalar(("select nextval('%s')" % \
-                    self.dialect.identifier_preparer.format_sequence(seq)), type_)
-        else:
-            return None
+        return self._execute_scalar(("select nextval('%s')" % \
+                self.dialect.identifier_preparer.format_sequence(seq)), type_)
 
     def get_insert_default(self, column):
         if column.primary_key and column is column.table._autoincrement_column:
-            if (isinstance(column.server_default, schema.DefaultClause) and
-                column.server_default.arg is not None):
+            if column.server_default and column.server_default.has_argument:
 
                 # pre-execute passive defaults on primary key columns
                 return self._execute_scalar("select %s" %
-                                        column.server_default.arg, column.type)
+                                    column.server_default.arg, column.type)
 
             elif (column.default is None or 
-                        (isinstance(column.default, schema.Sequence) and
+                        (column.default.is_sequence and
                         column.default.optional)):
 
                 # execute the sequence associated with a SERIAL primary 
index 70d9013d6029789b28b69ecb7c390adfbc9b6aee..bc3eac21308dcde90da08281fab20f251e52a0bc 100644 (file)
@@ -750,7 +750,7 @@ class Column(SchemaItem, expression.ColumnClause):
                     if isinstance(self.default, str):
                     # end Py2K
                         util.warn("Unicode column received non-unicode "
-                                  "default value.")                    
+                                  "default value.")
                 args.append(ColumnDefault(self.default))
 
         if self.server_default is not None:
@@ -1266,7 +1266,16 @@ class ForeignKey(SchemaItem):
             self.constraint._set_parent_with_dispatch(table)
         table.foreign_keys.add(self)
 
-class DefaultGenerator(SchemaItem):
+class _NotAColumnExpr(object):
+    def _not_a_column_expr(self):
+        raise exc.InvalidRequestError(
+                "This %s cannot be used directly "
+                "as a column expression." % self.__class__.__name__)
+
+    __clause_element__ = self_group = lambda self: self._not_a_column_expr()
+    _from_objects = property(lambda self: self._not_a_column_expr())
+
+class DefaultGenerator(_NotAColumnExpr, SchemaItem):
     """Base class for column *default* values."""
 
     __visit_name__ = 'default_generator'
@@ -1392,26 +1401,26 @@ class ColumnDefault(DefaultGenerator):
 
 class Sequence(DefaultGenerator):
     """Represents a named database sequence.
-    
+
     The :class:`.Sequence` object represents the name and configurational
     parameters of a database sequence.   It also represents
     a construct that can be "executed" by a SQLAlchemy :class:`.Engine`
     or :class:`.Connection`, rendering the appropriate "next value" function
     for the target database and returning a result.
-    
+
     The :class:`.Sequence` is typically associated with a primary key column::
-    
+
         some_table = Table('some_table', metadata,
             Column('id', Integer, Sequence('some_table_seq'), primary_key=True)
         )
-        
+
     When CREATE TABLE is emitted for the above :class:`.Table`, if the
     target platform supports sequences, a CREATE SEQUENCE statement will
     be emitted as well.   For platforms that don't support sequences,
     the :class:`.Sequence` construct is ignored.
-    
+
     See also: :class:`.CreateSequence` :class:`.DropSequence`
-    
+
     """
 
     __visit_name__ = 'sequence'
@@ -1422,7 +1431,7 @@ class Sequence(DefaultGenerator):
                  optional=False, quote=None, metadata=None, 
                  for_update=False):
         """Construct a :class:`.Sequence` object.
-        
+
         :param name: The name of the sequence.
         :param start: the starting index of the sequence.  This value is
          used when the CREATE SEQUENCE command is emitted to the database
@@ -1455,7 +1464,7 @@ class Sequence(DefaultGenerator):
          DROP SEQUENCE DDL commands will be emitted corresponding to this
          :class:`.Sequence` when :meth:`.MetaData.create_all` and 
          :meth:`.MetaData.drop_all` are invoked (new in 0.7).
-         
+
          Note that when a :class:`.Sequence` is applied to a :class:`.Column`, 
          the :class:`.Sequence` is automatically associated with the 
          :class:`.MetaData` object of that column's parent :class:`.Table`, 
@@ -1467,7 +1476,7 @@ class Sequence(DefaultGenerator):
          with a :class:`.Column`, should be invoked for UPDATE statements
          on that column's table, rather than for INSERT statements, when
          no value is otherwise present for that column in the statement.
-         
+
         """
         super(Sequence, self).__init__(for_update=for_update)
         self.name = name
@@ -1488,6 +1497,14 @@ class Sequence(DefaultGenerator):
     def is_clause_element(self):
         return False
 
+    def next_value(self):
+        """Return a :class:`.next_value` function element
+        which will render the appropriate increment function
+        for this :class:`.Sequence` within any SQL expression.
+        
+        """
+        return expression.func.next_value(self, bind=self.bind)
+
     def __repr__(self):
         return "Sequence(%s)" % ', '.join(
             [repr(self.name)] +
@@ -1526,8 +1543,16 @@ class Sequence(DefaultGenerator):
             bind = _bind_or_error(self)
         bind.drop(self, checkfirst=checkfirst)
 
+    def _not_a_column_expr(self):
+        raise exc.InvalidRequestError(
+                "This %s cannot be used directly "
+                "as a column expression.  Use func.next_value(sequence) "
+                "to produce a 'next value' function that's usable "
+                "as a column element." 
+                % self.__class__.__name__)
 
-class FetchedValue(events.SchemaEventTarget):
+
+class FetchedValue(_NotAColumnExpr, events.SchemaEventTarget):
     """A marker for a transparent database-side default.
 
     Use :class:`.FetchedValue` when the database is configured
@@ -1544,6 +1569,7 @@ class FetchedValue(events.SchemaEventTarget):
     """
     is_server_default = True
     reflected = False
+    has_argument = False
 
     def __init__(self, for_update=False):
         self.for_update = for_update
@@ -1581,6 +1607,8 @@ class DefaultClause(FetchedValue):
 
     """
 
+    has_argument = True
+
     def __init__(self, arg, for_update=False, _reflected=False):
         util.assert_arg_type(arg, (basestring,
                                    expression.ClauseElement,
@@ -2508,7 +2536,7 @@ class DDLElement(expression.Executable, expression.ClauseElement):
               Optional keyword argument - a list of Table objects which are to
               be created/ dropped within a MetaData.create_all() or drop_all()
               method call.
-              
+
             :state:
               Optional keyword argument - will be the ``state`` argument
               passed to this function.
@@ -2517,13 +2545,13 @@ class DDLElement(expression.Executable, expression.ClauseElement):
              Keyword argument, will be True if the 'checkfirst' flag was
              set during the call to ``create()``, ``create_all()``, 
              ``drop()``, ``drop_all()``.
-             
+
           If the callable returns a true value, the DDL statement will be
           executed.
 
         :param state: any value which will be passed to the callable_ 
           as the ``state`` keyword argument.
-          
+
         See also:
 
             :class:`.DDLEvents`
index d6a020bdccd19a18d8be1f2513a8e101445a1e52..7547e1662a083f7a87628d2e8b52e589f9e39488 100644 (file)
@@ -492,6 +492,14 @@ class SQLCompiler(engine.Compiled):
             return ".".join(func.packagenames + [name]) % \
                             {'expr':self.function_argspec(func, **kwargs)}
 
+    def visit_next_value_func(self, next_value, **kw):
+        return self.visit_sequence(next_value.sequence)
+
+    def visit_sequence(self, sequence):
+        raise NotImplementedError(
+            "Dialect '%s' does not support sequence increments." % self.dialect.name
+        )
+
     def function_argspec(self, func, **kwargs):
         return func.clause_expr._compiler_dispatch(self, **kwargs)
 
@@ -926,9 +934,6 @@ class SQLCompiler(engine.Compiled):
             join.onclause._compiler_dispatch(self, **kwargs)
         )
 
-    def visit_sequence(self, seq):
-        return None
-
     def visit_insert(self, insert_stmt):
         self.isinsert = True
         colparams = self._get_colparams(insert_stmt)
@@ -1075,6 +1080,9 @@ class SQLCompiler(engine.Compiled):
                 if sql._is_literal(value):
                     value = self._create_crud_bind_param(
                                     c, value, required=value is required)
+                elif c.primary_key and implicit_returning:
+                    self.returning.append(c)
+                    value = self.process(value.self_group())
                 else:
                     self.postfetch.append(c)
                     value = self.process(value.self_group())
@@ -1092,8 +1100,10 @@ class SQLCompiler(engine.Compiled):
                     if implicit_returning:
                         if c.default is not None:
                             if c.default.is_sequence:
-                                proc = self.process(c.default)
-                                if proc is not None:
+                                if self.dialect.supports_sequences and \
+                                    (not c.default.optional or \
+                                    not self.dialect.sequences_optional):
+                                    proc = self.process(c.default)
                                     values.append((c, proc))
                                 self.returning.append(c)
                             elif c.default.is_clause_element:
@@ -1124,8 +1134,10 @@ class SQLCompiler(engine.Compiled):
 
                 elif c.default is not None:
                     if c.default.is_sequence:
-                        proc = self.process(c.default)
-                        if proc is not None:
+                        if self.dialect.supports_sequences and \
+                            (not c.default.optional or \
+                            not self.dialect.sequences_optional):
+                            proc = self.process(c.default)
                             values.append((c, proc))
                             if not c.primary_key:
                                 self.postfetch.append(c)
index 9aed957d2f6ac933c27e4a6ac3f8430500ecc71c..d49f1215079a2d680be9dac44443f92156234031 100644 (file)
@@ -1178,14 +1178,16 @@ def _column_as_key(element):
     return element.key
 
 def _literal_as_text(element):
-    if hasattr(element, '__clause_element__'):
+    if isinstance(element, Visitable):
+        return element
+    elif hasattr(element, '__clause_element__'):
         return element.__clause_element__()
     elif isinstance(element, basestring):
         return _TextClause(unicode(element))
-    elif not isinstance(element, Visitable):
-        raise exc.ArgumentError("SQL expression object or string expected.")
     else:
-        return element
+        raise exc.ArgumentError(
+            "SQL expression object or string expected."
+        )
 
 def _clause_element_as_expr(element):
     if hasattr(element, '__clause_element__'):
index 10eaa577bca039f3d9b1e423e90e3b90eb0e67bb..717816656f9a80a64517f030c8e324cce3272d49 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from sqlalchemy import types as sqltypes
+from sqlalchemy import types as sqltypes, schema
 from sqlalchemy.sql.expression import (
     ClauseList, Function, _literal_as_binds, text, _type_from_args
     )
@@ -29,6 +29,29 @@ class GenericFunction(Function):
         self.type = sqltypes.to_instance(
             type_ or getattr(self, '__return_type__', None))
 
+
+class next_value(Function):
+    """Represent the 'next value', given a :class:`.Sequence`
+    as it's single argument.
+    
+    Compiles into the appropriate function on each backend,
+    or will raise NotImplementedError if used on a backend
+    that does not provide support for sequences.
+    
+    """
+    type = sqltypes.Integer()
+    name = "next_value"
+
+    def __init__(self, seq, **kw):
+        assert isinstance(seq, schema.Sequence), \
+                "next_value() accepts a Sequence object as input."
+        self._bind = kw.get('bind', None)
+        self.sequence = seq
+
+    @property
+    def _from_objects(self):
+        return []
+
 class AnsiFunction(GenericFunction):
     def __init__(self, **kwargs):
         GenericFunction.__init__(self, **kwargs)
@@ -52,6 +75,7 @@ class min(ReturnTypeFromArgs):
 class sum(ReturnTypeFromArgs):
     pass
 
+
 class now(GenericFunction):
     __return_type__ = sqltypes.DateTime
 
index eeb4ebbcb87a62200c7747d492e179931f35def4..79243ad4f69c79a872935b2015476a56a40a5f40 100644 (file)
@@ -10,7 +10,7 @@ from sqlalchemy.types import TypeDecorator
 from test.lib.schema import Table, Column
 from test.lib.testing import eq_
 from test.sql import _base
-
+from sqlalchemy.dialects import sqlite
 
 class DefaultTest(testing.TestBase):
 
@@ -303,6 +303,31 @@ class DefaultTest(testing.TestBase):
              (53, 'imthedefault', f, ts, ts, ctexec, True, False,
               12, today, 'py')])
 
+    def test_no_embed_in_sql(self):
+        """Using a DefaultGenerator, Sequence, DefaultClause
+        in the columns, where clause of a select, or in the values 
+        clause of insert, update, raises an informative error"""
+        for const in (
+            sa.Sequence('y'),
+            sa.ColumnDefault('y'),
+            sa.DefaultClause('y')
+        ):
+            assert_raises_message(
+                sa.exc.ArgumentError,
+                "SQL expression object or string expected.",
+                t.select, [const]
+            )
+            assert_raises_message(
+                sa.exc.InvalidRequestError,
+                "cannot be used directly as a column expression.",
+                str, t.insert().values(col4=const)
+            )
+            assert_raises_message(
+                sa.exc.InvalidRequestError,
+                "cannot be used directly as a column expression.",
+                str, t.update().values(col4=const)
+            )
+
     def test_missing_many_param(self):
         assert_raises_message(exc.StatementError, 
             "A value is required for bind parameter 'col7', in parameter group 1",
@@ -565,6 +590,132 @@ class SequenceDDLTest(testing.TestBase, testing.AssertsCompiledSQL):
             "DROP SEQUENCE foo_seq",
         )
 
+class SequenceExecTest(testing.TestBase):
+    __requires__ = ('sequences',)
+
+    @classmethod
+    def setup_class(cls):
+        cls.seq= Sequence("my_sequence")
+        cls.seq.create(testing.db)
+
+    @classmethod
+    def teardown_class(cls):
+        cls.seq.drop(testing.db)
+
+    def _assert_seq_result(self, ret):
+        """asserts return of next_value is an int"""
+        assert isinstance(ret, (int, long))
+        assert ret > 0
+
+    def test_implicit_connectionless(self):
+        s = Sequence("my_sequence", metadata=MetaData(testing.db))
+        self._assert_seq_result(s.execute())
+
+    def test_explicit(self):
+        s = Sequence("my_sequence")
+        self._assert_seq_result(s.execute(testing.db))
+
+    def test_explicit_optional(self):
+        """test dialect executes a Sequence, returns nextval, whether 
+        or not "optional" is set """
+
+        s = Sequence("my_sequence", optional=True)
+        self._assert_seq_result(s.execute(testing.db))
+
+    def test_func_implicit_connectionless_execute(self):
+        """test func.next_value().execute()/.scalar() works
+        with connectionless execution. """
+
+        s = Sequence("my_sequence", metadata=MetaData(testing.db))
+        self._assert_seq_result(s.next_value().execute().scalar())
+
+    def test_func_explicit(self):
+        s = Sequence("my_sequence")
+        self._assert_seq_result(testing.db.scalar(s.next_value()))
+
+    def test_func_implicit_connectionless_scalar(self):
+        """test func.next_value().execute()/.scalar() works. """
+
+        s = Sequence("my_sequence", metadata=MetaData(testing.db))
+        self._assert_seq_result(s.next_value().scalar())
+
+    def test_func_embedded_select(self):
+        """test can use next_value() in select column expr"""
+
+        s = Sequence("my_sequence")
+        self._assert_seq_result(
+            testing.db.scalar(select([s.next_value()]))
+        )
+
+    @testing.fails_on('oracle', "ORA-02287: sequence number not allowed here")
+    @testing.provide_metadata
+    def test_func_embedded_whereclause(self):
+        """test can use next_value() in whereclause"""
+
+        t1 = Table('t', metadata,
+            Column('x', Integer)
+        )
+        t1.create(testing.db)
+        testing.db.execute(t1.insert(), [{'x':1}, {'x':300}, {'x':301}])
+        s = Sequence("my_sequence")
+        eq_(
+            testing.db.execute(
+                t1.select().where(t1.c.x > s.next_value())
+            ).fetchall(),
+            [(300, ), (301, )]
+        )
+
+    @testing.provide_metadata
+    def test_func_embedded_valuesbase(self):
+        """test can use next_value() in values() of _ValuesBase"""
+
+        t1 = Table('t', metadata,
+            Column('x', Integer)
+        )
+        t1.create(testing.db)
+        s = Sequence("my_sequence")
+        testing.db.execute(
+            t1.insert().values(x=s.next_value())
+        )
+        self._assert_seq_result(
+            testing.db.scalar(t1.select())
+        )
+
+    @testing.provide_metadata
+    def test_inserted_pk_no_returning(self):
+        """test inserted_primary_key contains [None] when 
+        pk_col=next_value(), implicit returning is not used."""
+
+        e = engines.testing_engine(options={'implicit_returning':False})
+        s = Sequence("my_sequence")
+        metadata.bind = e
+        t1 = Table('t', metadata,
+            Column('x', Integer, primary_key=True)
+        )
+        t1.create()
+        r = e.execute(
+            t1.insert().values(x=s.next_value())
+        )
+        eq_(r.inserted_primary_key, [None])
+
+    @testing.requires.returning
+    @testing.provide_metadata
+    def test_inserted_pk_implicit_returning(self):
+        """test inserted_primary_key contains the result when 
+        pk_col=next_value(), when implicit returning is used."""
+
+        e = engines.testing_engine(options={'implicit_returning':True})
+        s = Sequence("my_sequence")
+        metadata.bind = e
+        t1 = Table('t', metadata,
+            Column('x', Integer, primary_key=True)
+        )
+        t1.create()
+        r = e.execute(
+            t1.insert().values(x=s.next_value())
+        )
+        self._assert_seq_result(r.inserted_primary_key[0])
+
 class SequenceTest(testing.TestBase, testing.AssertsCompiledSQL):
     __requires__ = ('sequences',)
 
@@ -586,29 +737,36 @@ class SequenceTest(testing.TestBase, testing.AssertsCompiledSQL):
             finally:
                 seq.drop(testing.db)
 
-    @testing.fails_on('maxdb', 'FIXME: unknown')
-    # maxdb db-api seems to double-execute NEXTVAL 
-    # internally somewhere,
-    # throwing off the numbers for these tests...
-    @testing.provide_metadata
-    def test_implicit_sequence_exec(self):
-        s = Sequence("my_sequence", metadata=metadata)
-        metadata.create_all()
-        x = s.execute()
-        eq_(x, 1)
 
     def _has_sequence(self, name):
         return testing.db.dialect.has_sequence(testing.db, name)
 
-    @testing.fails_on('maxdb', 'FIXME: unknown')
-    def teststandalone_explicit(self):
-        s = Sequence("my_sequence")
-        s.create(bind=testing.db)
-        try:
-            x = s.execute(testing.db)
-            eq_(x, 1)
-        finally:
-            s.drop(testing.db)
+    def test_nextval_render(self):
+        """test dialect renders the "nextval" construct, 
+        whether or not "optional" is set """
+
+        for s in (
+                Sequence("my_seq"), 
+                Sequence("my_seq", optional=True)):
+            assert str(s.next_value().
+                    compile(dialect=testing.db.dialect)) in (
+                "nextval('my_seq')",
+                "gen_id(my_seq, 1)",
+                "my_seq.nextval",
+            )
+
+    def test_nextval_unsupported(self):
+        """test next_value() used on non-sequence platform 
+        raises NotImplementedError."""
+
+        s = Sequence("my_seq")
+        d = sqlite.dialect()
+        assert_raises_message(
+            NotImplementedError,
+            "Dialect 'sqlite' does not support sequence increments.",
+            s.next_value().compile,
+            dialect=d
+        )
 
     def test_checkfirst_sequence(self):
         s = Sequence("my_sequence")
@@ -733,10 +891,10 @@ class TableBoundSequenceTest(testing.TestBase):
 
 class SpecialTypePKTest(testing.TestBase):
     """test process_result_value in conjunction with primary key columns.
-    
+
     Also tests that "autoincrement" checks are against column.type._type_affinity,
     rather than the class of "type" itself.
-    
+
     """
 
     @classmethod
@@ -818,12 +976,12 @@ class ServerDefaultsOnPKTest(testing.TestBase):
     @testing.provide_metadata
     def test_string_default_none_on_insert(self):
         """Test that without implicit returning, we return None for 
-        a string server default.  
-        
+        a string server default.
+
         That is, we don't want to attempt to pre-execute "server_default"
         generically - the user should use a Python side-default for a case
         like this.   Testing that all backends do the same thing here.
-        
+
         """
         t = Table('x', metadata, 
                 Column('y', String(10), server_default='key_one', primary_key=True),
@@ -946,7 +1104,7 @@ class UnicodeDefaultsTest(testing.TestBase):
         # end Py2K
         c = Column(Unicode(32), default=default)
 
-    
+
     def test_nonunicode_default(self):
         # Py3K
         #default = b'foo'