From: Mike Bayer Date: Sun, 20 Mar 2011 16:49:28 +0000 (-0400) Subject: - Added new generic function "next_value()", accepts X-Git-Tag: rel_0_7b3~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=90335a89a98df23db7a3ae1233eb4fbb5743d2e8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Added new generic function "next_value()", accepts a Sequence object as its argument and renders the appropriate "next value" generation string on the target platform, if supported. Also provides ".next_value()" method on Sequence itself. [ticket:2085] - added tests for all the conditions described in [ticket:2085] - postgresql dialect will exec/compile a Sequence that has "optional=True". the optional flag is now only checked specifically in the context of a Table primary key evaulation. - func.next_value() or other SQL expression can be embedded directly into an insert() construct, and if implicit or explicit "returning" is used in conjunction with a primary key column, the newly generated value will be present in result.inserted_primary_key. [ticket:2084] --- diff --git a/CHANGES b/CHANGES index c7e872340b..f4d32a6044 100644 --- a/CHANGES +++ b/CHANGES @@ -84,6 +84,20 @@ CHANGES (, ), which are applied to the Table before the reflection process begins. + - Added new generic function "next_value()", accepts + a Sequence object as its argument and renders the + appropriate "next value" generation string on the + target platform, if supported. Also provides + ".next_value()" method on Sequence itself. + [ticket:2085] + + - func.next_value() or other SQL expression can + be embedded directly into an insert() construct, + and if implicit or explicit "returning" is used + in conjunction with a primary key column, + the newly generated value will be present in + result.inserted_primary_key. [ticket:2084] + - engine - Fixed AssertionPool regression bug. [ticket:2097] diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index d3c1bc1391..72411d7357 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -392,7 +392,8 @@ class OracleCompiler(compiler.SQLCompiler): return "" def default_from(self): - """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. + """Called when a ``SELECT`` statement has no froms, + and no ``FROM`` clause is to be appended. The Oracle compiler tacks a "FROM DUAL" to the statement. """ diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 8bceeef657..cc2f461f98 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -461,10 +461,7 @@ class PGCompiler(compiler.SQLCompiler): return value def visit_sequence(self, seq): - if seq.optional: - return None - else: - return "nextval('%s')" % self.preparer.format_sequence(seq) + return "nextval('%s')" % self.preparer.format_sequence(seq) def limit_clause(self, select): text = "" @@ -717,23 +714,19 @@ class DropEnumType(schema._CreateDropBase): class PGExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): - if not seq.optional: - return self._execute_scalar(("select nextval('%s')" % \ - self.dialect.identifier_preparer.format_sequence(seq)), type_) - else: - return None + return self._execute_scalar(("select nextval('%s')" % \ + self.dialect.identifier_preparer.format_sequence(seq)), type_) def get_insert_default(self, column): if column.primary_key and column is column.table._autoincrement_column: - if (isinstance(column.server_default, schema.DefaultClause) and - column.server_default.arg is not None): + if column.server_default and column.server_default.has_argument: # pre-execute passive defaults on primary key columns return self._execute_scalar("select %s" % - column.server_default.arg, column.type) + column.server_default.arg, column.type) elif (column.default is None or - (isinstance(column.default, schema.Sequence) and + (column.default.is_sequence and column.default.optional)): # execute the sequence associated with a SERIAL primary diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 70d9013d60..bc3eac2130 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -750,7 +750,7 @@ class Column(SchemaItem, expression.ColumnClause): if isinstance(self.default, str): # end Py2K util.warn("Unicode column received non-unicode " - "default value.") + "default value.") args.append(ColumnDefault(self.default)) if self.server_default is not None: @@ -1266,7 +1266,16 @@ class ForeignKey(SchemaItem): self.constraint._set_parent_with_dispatch(table) table.foreign_keys.add(self) -class DefaultGenerator(SchemaItem): +class _NotAColumnExpr(object): + def _not_a_column_expr(self): + raise exc.InvalidRequestError( + "This %s cannot be used directly " + "as a column expression." % self.__class__.__name__) + + __clause_element__ = self_group = lambda self: self._not_a_column_expr() + _from_objects = property(lambda self: self._not_a_column_expr()) + +class DefaultGenerator(_NotAColumnExpr, SchemaItem): """Base class for column *default* values.""" __visit_name__ = 'default_generator' @@ -1392,26 +1401,26 @@ class ColumnDefault(DefaultGenerator): class Sequence(DefaultGenerator): """Represents a named database sequence. - + The :class:`.Sequence` object represents the name and configurational parameters of a database sequence. It also represents a construct that can be "executed" by a SQLAlchemy :class:`.Engine` or :class:`.Connection`, rendering the appropriate "next value" function for the target database and returning a result. - + The :class:`.Sequence` is typically associated with a primary key column:: - + some_table = Table('some_table', metadata, Column('id', Integer, Sequence('some_table_seq'), primary_key=True) ) - + When CREATE TABLE is emitted for the above :class:`.Table`, if the target platform supports sequences, a CREATE SEQUENCE statement will be emitted as well. For platforms that don't support sequences, the :class:`.Sequence` construct is ignored. - + See also: :class:`.CreateSequence` :class:`.DropSequence` - + """ __visit_name__ = 'sequence' @@ -1422,7 +1431,7 @@ class Sequence(DefaultGenerator): optional=False, quote=None, metadata=None, for_update=False): """Construct a :class:`.Sequence` object. - + :param name: The name of the sequence. :param start: the starting index of the sequence. This value is used when the CREATE SEQUENCE command is emitted to the database @@ -1455,7 +1464,7 @@ class Sequence(DefaultGenerator): DROP SEQUENCE DDL commands will be emitted corresponding to this :class:`.Sequence` when :meth:`.MetaData.create_all` and :meth:`.MetaData.drop_all` are invoked (new in 0.7). - + Note that when a :class:`.Sequence` is applied to a :class:`.Column`, the :class:`.Sequence` is automatically associated with the :class:`.MetaData` object of that column's parent :class:`.Table`, @@ -1467,7 +1476,7 @@ class Sequence(DefaultGenerator): with a :class:`.Column`, should be invoked for UPDATE statements on that column's table, rather than for INSERT statements, when no value is otherwise present for that column in the statement. - + """ super(Sequence, self).__init__(for_update=for_update) self.name = name @@ -1488,6 +1497,14 @@ class Sequence(DefaultGenerator): def is_clause_element(self): return False + def next_value(self): + """Return a :class:`.next_value` function element + which will render the appropriate increment function + for this :class:`.Sequence` within any SQL expression. + + """ + return expression.func.next_value(self, bind=self.bind) + def __repr__(self): return "Sequence(%s)" % ', '.join( [repr(self.name)] + @@ -1526,8 +1543,16 @@ class Sequence(DefaultGenerator): bind = _bind_or_error(self) bind.drop(self, checkfirst=checkfirst) + def _not_a_column_expr(self): + raise exc.InvalidRequestError( + "This %s cannot be used directly " + "as a column expression. Use func.next_value(sequence) " + "to produce a 'next value' function that's usable " + "as a column element." + % self.__class__.__name__) -class FetchedValue(events.SchemaEventTarget): + +class FetchedValue(_NotAColumnExpr, events.SchemaEventTarget): """A marker for a transparent database-side default. Use :class:`.FetchedValue` when the database is configured @@ -1544,6 +1569,7 @@ class FetchedValue(events.SchemaEventTarget): """ is_server_default = True reflected = False + has_argument = False def __init__(self, for_update=False): self.for_update = for_update @@ -1581,6 +1607,8 @@ class DefaultClause(FetchedValue): """ + has_argument = True + def __init__(self, arg, for_update=False, _reflected=False): util.assert_arg_type(arg, (basestring, expression.ClauseElement, @@ -2508,7 +2536,7 @@ class DDLElement(expression.Executable, expression.ClauseElement): Optional keyword argument - a list of Table objects which are to be created/ dropped within a MetaData.create_all() or drop_all() method call. - + :state: Optional keyword argument - will be the ``state`` argument passed to this function. @@ -2517,13 +2545,13 @@ class DDLElement(expression.Executable, expression.ClauseElement): Keyword argument, will be True if the 'checkfirst' flag was set during the call to ``create()``, ``create_all()``, ``drop()``, ``drop_all()``. - + If the callable returns a true value, the DDL statement will be executed. :param state: any value which will be passed to the callable_ as the ``state`` keyword argument. - + See also: :class:`.DDLEvents` diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d6a020bdcc..7547e1662a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -492,6 +492,14 @@ class SQLCompiler(engine.Compiled): return ".".join(func.packagenames + [name]) % \ {'expr':self.function_argspec(func, **kwargs)} + def visit_next_value_func(self, next_value, **kw): + return self.visit_sequence(next_value.sequence) + + def visit_sequence(self, sequence): + raise NotImplementedError( + "Dialect '%s' does not support sequence increments." % self.dialect.name + ) + def function_argspec(self, func, **kwargs): return func.clause_expr._compiler_dispatch(self, **kwargs) @@ -926,9 +934,6 @@ class SQLCompiler(engine.Compiled): join.onclause._compiler_dispatch(self, **kwargs) ) - def visit_sequence(self, seq): - return None - def visit_insert(self, insert_stmt): self.isinsert = True colparams = self._get_colparams(insert_stmt) @@ -1075,6 +1080,9 @@ class SQLCompiler(engine.Compiled): if sql._is_literal(value): value = self._create_crud_bind_param( c, value, required=value is required) + elif c.primary_key and implicit_returning: + self.returning.append(c) + value = self.process(value.self_group()) else: self.postfetch.append(c) value = self.process(value.self_group()) @@ -1092,8 +1100,10 @@ class SQLCompiler(engine.Compiled): if implicit_returning: if c.default is not None: if c.default.is_sequence: - proc = self.process(c.default) - if proc is not None: + if self.dialect.supports_sequences and \ + (not c.default.optional or \ + not self.dialect.sequences_optional): + proc = self.process(c.default) values.append((c, proc)) self.returning.append(c) elif c.default.is_clause_element: @@ -1124,8 +1134,10 @@ class SQLCompiler(engine.Compiled): elif c.default is not None: if c.default.is_sequence: - proc = self.process(c.default) - if proc is not None: + if self.dialect.supports_sequences and \ + (not c.default.optional or \ + not self.dialect.sequences_optional): + proc = self.process(c.default) values.append((c, proc)) if not c.primary_key: self.postfetch.append(c) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 9aed957d2f..d49f121507 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1178,14 +1178,16 @@ def _column_as_key(element): return element.key def _literal_as_text(element): - if hasattr(element, '__clause_element__'): + if isinstance(element, Visitable): + return element + elif hasattr(element, '__clause_element__'): return element.__clause_element__() elif isinstance(element, basestring): return _TextClause(unicode(element)) - elif not isinstance(element, Visitable): - raise exc.ArgumentError("SQL expression object or string expected.") else: - return element + raise exc.ArgumentError( + "SQL expression object or string expected." + ) def _clause_element_as_expr(element): if hasattr(element, '__clause_element__'): diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 10eaa577bc..717816656f 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import types as sqltypes +from sqlalchemy import types as sqltypes, schema from sqlalchemy.sql.expression import ( ClauseList, Function, _literal_as_binds, text, _type_from_args ) @@ -29,6 +29,29 @@ class GenericFunction(Function): self.type = sqltypes.to_instance( type_ or getattr(self, '__return_type__', None)) + +class next_value(Function): + """Represent the 'next value', given a :class:`.Sequence` + as it's single argument. + + Compiles into the appropriate function on each backend, + or will raise NotImplementedError if used on a backend + that does not provide support for sequences. + + """ + type = sqltypes.Integer() + name = "next_value" + + def __init__(self, seq, **kw): + assert isinstance(seq, schema.Sequence), \ + "next_value() accepts a Sequence object as input." + self._bind = kw.get('bind', None) + self.sequence = seq + + @property + def _from_objects(self): + return [] + class AnsiFunction(GenericFunction): def __init__(self, **kwargs): GenericFunction.__init__(self, **kwargs) @@ -52,6 +75,7 @@ class min(ReturnTypeFromArgs): class sum(ReturnTypeFromArgs): pass + class now(GenericFunction): __return_type__ = sqltypes.DateTime diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index eeb4ebbcb8..79243ad4f6 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -10,7 +10,7 @@ from sqlalchemy.types import TypeDecorator from test.lib.schema import Table, Column from test.lib.testing import eq_ from test.sql import _base - +from sqlalchemy.dialects import sqlite class DefaultTest(testing.TestBase): @@ -303,6 +303,31 @@ class DefaultTest(testing.TestBase): (53, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today, 'py')]) + def test_no_embed_in_sql(self): + """Using a DefaultGenerator, Sequence, DefaultClause + in the columns, where clause of a select, or in the values + clause of insert, update, raises an informative error""" + for const in ( + sa.Sequence('y'), + sa.ColumnDefault('y'), + sa.DefaultClause('y') + ): + assert_raises_message( + sa.exc.ArgumentError, + "SQL expression object or string expected.", + t.select, [const] + ) + assert_raises_message( + sa.exc.InvalidRequestError, + "cannot be used directly as a column expression.", + str, t.insert().values(col4=const) + ) + assert_raises_message( + sa.exc.InvalidRequestError, + "cannot be used directly as a column expression.", + str, t.update().values(col4=const) + ) + def test_missing_many_param(self): assert_raises_message(exc.StatementError, "A value is required for bind parameter 'col7', in parameter group 1", @@ -565,6 +590,132 @@ class SequenceDDLTest(testing.TestBase, testing.AssertsCompiledSQL): "DROP SEQUENCE foo_seq", ) +class SequenceExecTest(testing.TestBase): + __requires__ = ('sequences',) + + @classmethod + def setup_class(cls): + cls.seq= Sequence("my_sequence") + cls.seq.create(testing.db) + + @classmethod + def teardown_class(cls): + cls.seq.drop(testing.db) + + def _assert_seq_result(self, ret): + """asserts return of next_value is an int""" + assert isinstance(ret, (int, long)) + assert ret > 0 + + def test_implicit_connectionless(self): + s = Sequence("my_sequence", metadata=MetaData(testing.db)) + self._assert_seq_result(s.execute()) + + def test_explicit(self): + s = Sequence("my_sequence") + self._assert_seq_result(s.execute(testing.db)) + + def test_explicit_optional(self): + """test dialect executes a Sequence, returns nextval, whether + or not "optional" is set """ + + s = Sequence("my_sequence", optional=True) + self._assert_seq_result(s.execute(testing.db)) + + def test_func_implicit_connectionless_execute(self): + """test func.next_value().execute()/.scalar() works + with connectionless execution. """ + + s = Sequence("my_sequence", metadata=MetaData(testing.db)) + self._assert_seq_result(s.next_value().execute().scalar()) + + def test_func_explicit(self): + s = Sequence("my_sequence") + self._assert_seq_result(testing.db.scalar(s.next_value())) + + def test_func_implicit_connectionless_scalar(self): + """test func.next_value().execute()/.scalar() works. """ + + s = Sequence("my_sequence", metadata=MetaData(testing.db)) + self._assert_seq_result(s.next_value().scalar()) + + def test_func_embedded_select(self): + """test can use next_value() in select column expr""" + + s = Sequence("my_sequence") + self._assert_seq_result( + testing.db.scalar(select([s.next_value()])) + ) + + @testing.fails_on('oracle', "ORA-02287: sequence number not allowed here") + @testing.provide_metadata + def test_func_embedded_whereclause(self): + """test can use next_value() in whereclause""" + + t1 = Table('t', metadata, + Column('x', Integer) + ) + t1.create(testing.db) + testing.db.execute(t1.insert(), [{'x':1}, {'x':300}, {'x':301}]) + s = Sequence("my_sequence") + eq_( + testing.db.execute( + t1.select().where(t1.c.x > s.next_value()) + ).fetchall(), + [(300, ), (301, )] + ) + + @testing.provide_metadata + def test_func_embedded_valuesbase(self): + """test can use next_value() in values() of _ValuesBase""" + + t1 = Table('t', metadata, + Column('x', Integer) + ) + t1.create(testing.db) + s = Sequence("my_sequence") + testing.db.execute( + t1.insert().values(x=s.next_value()) + ) + self._assert_seq_result( + testing.db.scalar(t1.select()) + ) + + @testing.provide_metadata + def test_inserted_pk_no_returning(self): + """test inserted_primary_key contains [None] when + pk_col=next_value(), implicit returning is not used.""" + + e = engines.testing_engine(options={'implicit_returning':False}) + s = Sequence("my_sequence") + metadata.bind = e + t1 = Table('t', metadata, + Column('x', Integer, primary_key=True) + ) + t1.create() + r = e.execute( + t1.insert().values(x=s.next_value()) + ) + eq_(r.inserted_primary_key, [None]) + + @testing.requires.returning + @testing.provide_metadata + def test_inserted_pk_implicit_returning(self): + """test inserted_primary_key contains the result when + pk_col=next_value(), when implicit returning is used.""" + + e = engines.testing_engine(options={'implicit_returning':True}) + s = Sequence("my_sequence") + metadata.bind = e + t1 = Table('t', metadata, + Column('x', Integer, primary_key=True) + ) + t1.create() + r = e.execute( + t1.insert().values(x=s.next_value()) + ) + self._assert_seq_result(r.inserted_primary_key[0]) + class SequenceTest(testing.TestBase, testing.AssertsCompiledSQL): __requires__ = ('sequences',) @@ -586,29 +737,36 @@ class SequenceTest(testing.TestBase, testing.AssertsCompiledSQL): finally: seq.drop(testing.db) - @testing.fails_on('maxdb', 'FIXME: unknown') - # maxdb db-api seems to double-execute NEXTVAL - # internally somewhere, - # throwing off the numbers for these tests... - @testing.provide_metadata - def test_implicit_sequence_exec(self): - s = Sequence("my_sequence", metadata=metadata) - metadata.create_all() - x = s.execute() - eq_(x, 1) def _has_sequence(self, name): return testing.db.dialect.has_sequence(testing.db, name) - @testing.fails_on('maxdb', 'FIXME: unknown') - def teststandalone_explicit(self): - s = Sequence("my_sequence") - s.create(bind=testing.db) - try: - x = s.execute(testing.db) - eq_(x, 1) - finally: - s.drop(testing.db) + def test_nextval_render(self): + """test dialect renders the "nextval" construct, + whether or not "optional" is set """ + + for s in ( + Sequence("my_seq"), + Sequence("my_seq", optional=True)): + assert str(s.next_value(). + compile(dialect=testing.db.dialect)) in ( + "nextval('my_seq')", + "gen_id(my_seq, 1)", + "my_seq.nextval", + ) + + def test_nextval_unsupported(self): + """test next_value() used on non-sequence platform + raises NotImplementedError.""" + + s = Sequence("my_seq") + d = sqlite.dialect() + assert_raises_message( + NotImplementedError, + "Dialect 'sqlite' does not support sequence increments.", + s.next_value().compile, + dialect=d + ) def test_checkfirst_sequence(self): s = Sequence("my_sequence") @@ -733,10 +891,10 @@ class TableBoundSequenceTest(testing.TestBase): class SpecialTypePKTest(testing.TestBase): """test process_result_value in conjunction with primary key columns. - + Also tests that "autoincrement" checks are against column.type._type_affinity, rather than the class of "type" itself. - + """ @classmethod @@ -818,12 +976,12 @@ class ServerDefaultsOnPKTest(testing.TestBase): @testing.provide_metadata def test_string_default_none_on_insert(self): """Test that without implicit returning, we return None for - a string server default. - + a string server default. + That is, we don't want to attempt to pre-execute "server_default" generically - the user should use a Python side-default for a case like this. Testing that all backends do the same thing here. - + """ t = Table('x', metadata, Column('y', String(10), server_default='key_one', primary_key=True), @@ -946,7 +1104,7 @@ class UnicodeDefaultsTest(testing.TestBase): # end Py2K c = Column(Unicode(32), default=default) - + def test_nonunicode_default(self): # Py3K #default = b'foo'