]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- getting slightly more consistent behavior for the edge case of pk columns
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 15 Jan 2011 21:42:29 +0000 (16:42 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 15 Jan 2011 21:42:29 +0000 (16:42 -0500)
with server default - autoincrement is now false with any server_default,
so these all return None, applies consistency to [ticket:2020], [ticket:2021].
if prefetch is desired a "default" should be used instead of server_default.

doc/build/core/expression_api.rst
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/sql/test_defaults.py

index e907b25352f17d943655d170572df125a72ad0b3..88c0840ac0c10aeb55e93f3defd978370b858416 100644 (file)
@@ -185,7 +185,7 @@ Classes
    :show-inheritance:
 
 .. autoclass:: Insert
-   :members: prefix_with, values
+   :members: prefix_with, values, returning
    :show-inheritance:
 
 .. autoclass:: Join
index 7ed4ca07e2096122cc441b6f388e20c4f2bcd196..63ad37ce973a7d805fa0f48ecc509045add31064 100644 (file)
@@ -594,9 +594,9 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer):
 
 class OracleExecutionContext(default.DefaultExecutionContext):
     def fire_sequence(self, seq, type_):
-        return int(self._execute_scalar("SELECT " + 
+        return self._execute_scalar("SELECT " + 
                     self.dialect.identifier_preparer.format_sequence(seq) + 
-                    ".nextval FROM DUAL"), type_)
+                    ".nextval FROM DUAL", type_)
 
 class OracleDialect(default.DefaultDialect):
     name = 'oracle'
index 7c712e8aa36eddeaec89c5949538d96f8e3b9360..a8fb4e51a7bf2b83d90fb9dc6b82c491c909a37b 100644 (file)
@@ -508,13 +508,15 @@ class PGDDLCompiler(compiler.DDLCompiler):
         colspec = self.preparer.format_column(column)
         type_affinity = column.type._type_affinity
         if column.primary_key and \
-            len(column.foreign_keys)==0 and \
-            column.autoincrement and \
-            issubclass(type_affinity, sqltypes.Integer) and \
+            column is column.table._autoincrement_column and \
             not issubclass(type_affinity, sqltypes.SmallInteger) and \
-            (column.default is None or 
-                (isinstance(column.default, schema.Sequence) and
-                column.default.optional)):
+            (
+                column.default is None or 
+                (
+                    isinstance(column.default, schema.Sequence) and
+                    column.default.optional
+                )
+            ):
             if issubclass(type_affinity, sqltypes.BigInteger):
                 colspec += " BIGSERIAL"
             else:
@@ -689,7 +691,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
             return None
 
     def get_insert_default(self, column):
-        if column.primary_key:
+        if column.primary_key and column is column.table._autoincrement_column:
             if (isinstance(column.server_default, schema.DefaultClause) and
                 column.server_default.arg is not None):
 
@@ -697,8 +699,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
                 return self._execute_scalar("select %s" %
                                         column.server_default.arg, column.type)
 
-            elif column is column.table._autoincrement_column \
-                    and (column.default is None or 
+            elif (column.default is None or 
                         (isinstance(column.default, schema.Sequence) and
                         column.default.optional)):
 
index 3bdcad2ac8ca23a3f72f6969eabc40b0704b0928..9eb1b8b40068a12dec33f598ce31d3f5bc2219d0 100644 (file)
@@ -2460,9 +2460,23 @@ class ResultProxy(object):
     @util.memoized_property
     def inserted_primary_key(self):
         """Return the primary key for the row just inserted.
-
-        This only applies to single row insert() constructs which
-        did not explicitly specify returning().
+        
+        The return value is a list of scalar values 
+        corresponding to the list of primary key columns
+        in the target table.
+
+        This only applies to single row :func:`.insert` 
+        constructs which did not explicitly specify 
+        :meth:`.Insert.returning`.
+        
+        Note that primary key columns which specify a
+        server_default clause, 
+        or otherwise do not qualify as "autoincrement"
+        columns (see the notes at :class:`.Column`), and were
+        generated using the database-side default, will
+        appear in this list as ``None`` unless the backend 
+        supports "returning" and the insert statement executed
+        with the "implicit returning" enabled.
 
         """
 
index e21ec1c4070411ab79340015f1b95c356d07dc2b..da6ed12a6d2735b97cd92afd0cbca329d8e5c38a 100644 (file)
@@ -101,7 +101,7 @@ class DefaultDialect(base.Dialect):
 
         if not getattr(self, 'ported_sqla_06', True):
             util.warn(
-                "The %s dialect is not yet ported to SQLAlchemy 0.6" %
+                "The %s dialect is not yet ported to SQLAlchemy 0.6/0.7" %
                 self.name)
 
         self.convert_unicode = convert_unicode
@@ -625,7 +625,8 @@ class DefaultExecutionContext(base.ExecutionContext):
         return self.dialect.supports_sane_multi_rowcount
 
     def post_insert(self):
-        if self.dialect.postfetch_lastrowid and \
+        if not self._is_implicit_returning and \
+            self.dialect.postfetch_lastrowid and \
             (not self.inserted_primary_key or \
                         None in self.inserted_primary_key):
 
index cf254cba607c3d6ec0bdcee1d005f56ebdf11231..00b2fd1bf18c051fd0d1db51b6612a247abf9ee2 100644 (file)
@@ -398,7 +398,11 @@ class Inspector(object):
             if col_d.get('default') is not None:
                 # the "default" value is assumed to be a literal SQL expression,
                 # so is wrapped in text() so that no quoting occurs on re-issuance.
-                colargs.append(sa_schema.DefaultClause(sql.text(col_d['default'])))
+                colargs.append(
+                    sa_schema.DefaultClause(
+                        sql.text(col_d['default']), _reflected=True
+                    )
+                )
 
             if 'sequence' in col_d:
                 # TODO: mssql, maxdb and sybase are using this.
index a530a1a7a781e54be1b1b3f30374d0457424b96b..26f607512634dd99eef2ee235f7101f045fb9c4f 100644 (file)
@@ -325,11 +325,8 @@ class Table(SchemaItem, expression.TableClause):
             if col.autoincrement and \
                 issubclass(col.type._type_affinity, types.Integer) and \
                 not col.foreign_keys and \
-                isinstance(col.default, (type(None), Sequence)):
-                # don't look at server_default here since different backends may
-                # or may not have a server_default, e.g. postgresql reflected
-                # SERIAL cols will have a DefaultClause here but are still
-                # autoincrement. 
+                isinstance(col.default, (type(None), Sequence)) and \
+                (col.server_default is None or col.server_default.reflected):
                 return col
 
     @property
@@ -1231,6 +1228,7 @@ class DefaultGenerator(SchemaItem):
     __visit_name__ = 'default_generator'
 
     is_sequence = False
+    is_server_default = False
 
     def __init__(self, for_update=False):
         self.for_update = for_update
@@ -1423,6 +1421,8 @@ class FetchedValue(object):
     INSERT.
 
     """
+    is_server_default = True
+    reflected = False
 
     def __init__(self, for_update=False):
         self.for_update = for_update
@@ -1460,12 +1460,13 @@ class DefaultClause(FetchedValue):
 
     """
 
-    def __init__(self, arg, for_update=False):
+    def __init__(self, arg, for_update=False, _reflected=False):
         util.assert_arg_type(arg, (basestring,
                                    expression.ClauseElement,
                                    expression._TextClause), 'arg')
         super(DefaultClause, self).__init__(for_update)
         self.arg = arg
+        self.reflected = _reflected
 
     def __repr__(self):
         return "DefaultClause(%r, for_update=%r)" % \
index ce98dfb83cef40009ff6aea647565ce67a1c43e2..d906bf5d46acd8001b7c49fdca6efb99a9006bda 100644 (file)
@@ -1102,18 +1102,16 @@ class SQLCompiler(engine.Compiled):
                         else:
                             self.returning.append(c)
                     else:
-                        if (
-                            c.default is not None and \
-                                (
-                                    self.dialect.supports_sequences or 
-                                    not c.default.is_sequence
-                                )
-                            ) or \
-                             self.dialect.preexecute_autoincrement_sequences:
+                        if c.default is not None or \
+                            c is stmt.table._autoincrement_column and (
+                                self.dialect.supports_sequences or
+                                self.dialect.preexecute_autoincrement_sequences
+                            ):
 
                             values.append(
                                 (c, self._create_crud_bind_param(c, None))
                             )
+
                             self.prefetch.append(c)
 
                 elif c.default is not None:
index ede194f7ccfcf4920b78d29d2b8fd2fa6d95ddf7..6a368b8c05d6c5607afadddb0ad909d6f8ccdc33 100644 (file)
@@ -4436,7 +4436,7 @@ class Select(_SelectBase):
         self._bind = bind
     bind = property(bind, _set_bind)
 
-class _UpdateBase(Executable, ClauseElement):
+class UpdateBase(Executable, ClauseElement):
     """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
 
     __visit_name__ = 'update_base'
@@ -4513,7 +4513,8 @@ class _UpdateBase(Executable, ClauseElement):
         """
         self._returning = cols
 
-class _ValuesBase(_UpdateBase):
+class ValuesBase(UpdateBase):
+    """Supplies support for :meth:`.ValuesBase.values` to INSERT and UPDATE constructs."""
 
     __visit_name__ = 'values_base'
 
@@ -4548,7 +4549,7 @@ class _ValuesBase(_UpdateBase):
             self.parameters.update(self._process_colparams(v))
             self.parameters.update(kwargs)
 
-class Insert(_ValuesBase):
+class Insert(ValuesBase):
     """Represent an INSERT construct.
 
     The :class:`Insert` object is created using the :func:`insert()` function.
@@ -4566,7 +4567,7 @@ class Insert(_ValuesBase):
                 prefixes=None, 
                 returning=None,
                 **kwargs):
-        _ValuesBase.__init__(self, table, values)
+        ValuesBase.__init__(self, table, values)
         self._bind = bind
         self.select = None
         self.inline = inline
@@ -4598,7 +4599,7 @@ class Insert(_ValuesBase):
         clause = _literal_as_text(clause)
         self._prefixes = self._prefixes + (clause,)
 
-class Update(_ValuesBase):
+class Update(ValuesBase):
     """Represent an Update construct.
 
     The :class:`Update` object is created using the :func:`update()` function.
@@ -4614,7 +4615,7 @@ class Update(_ValuesBase):
                 bind=None, 
                 returning=None,
                 **kwargs):
-        _ValuesBase.__init__(self, table, values)
+        ValuesBase.__init__(self, table, values)
         self._bind = bind
         self._returning = returning
         if whereclause is not None:
@@ -4650,7 +4651,7 @@ class Update(_ValuesBase):
             self._whereclause = _literal_as_text(whereclause)
 
 
-class Delete(_UpdateBase):
+class Delete(UpdateBase):
     """Represent a DELETE construct.
 
     The :class:`Delete` object is created using the :func:`delete()` function.
index 7822e487ccd2f39cca0564267be1ba754a425934..0d099a78686fc82a1a585b9650dbdcb21edf6a13 100644 (file)
@@ -5,9 +5,9 @@ from sqlalchemy.sql import select, text, literal_column
 import sqlalchemy as sa
 from test.lib import testing, engines
 from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean, exc,\
-                Sequence, Column, func, literal
+                Sequence, func, literal
 from sqlalchemy.types import TypeDecorator
-from test.lib.schema import Table
+from test.lib.schema import Table, Column
 from test.lib.testing import eq_
 from test.sql import _base
 
@@ -704,9 +704,13 @@ class SpecialTypePKTest(testing.TestBase):
         class MyInteger(TypeDecorator):
             impl = Integer
             def process_bind_param(self, value, dialect):
+                if value is None:
+                    return None
                 return int(value[4:])
 
             def process_result_value(self, value, dialect):
+                if value is None:
+                    return None
                 return "INT_%d" % value
 
         cls.MyInteger = MyInteger
@@ -715,6 +719,8 @@ class SpecialTypePKTest(testing.TestBase):
     def _run_test(self, *arg, **kw):
         implicit_returning = kw.pop('implicit_returning', True)
         kw['primary_key'] = True
+        if kw.get('autoincrement', True):
+            kw['test_needs_autoincrement'] = True
         t = Table('x', metadata,
             Column('y', self.MyInteger, *arg, **kw),
             Column('data', Integer),
@@ -723,7 +729,12 @@ class SpecialTypePKTest(testing.TestBase):
 
         t.create()
         r = t.insert().values(data=5).execute()
-        eq_(r.inserted_primary_key, ['INT_1'])
+
+        # we don't pre-fetch 'server_default'.
+        if 'server_default' in kw and (not testing.db.dialect.implicit_returning or not implicit_returning):
+            eq_(r.inserted_primary_key, [None])
+        else:
+            eq_(r.inserted_primary_key, ['INT_1'])
         r.close()
 
         eq_(
@@ -745,13 +756,9 @@ class SpecialTypePKTest(testing.TestBase):
     def test_sequence(self):
         self._run_test(Sequence('foo_seq'))
 
-    @testing.fails_on('mysql', "Pending [ticket:2021]")
     def test_server_default(self):
-        # note that the MySQL dialect has to not render AUTOINCREMENT on this one
         self._run_test(server_default='1',)
 
-    @testing.fails_on('mysql', "Pending [ticket:2021]")
-    @testing.fails_on('sqlite', "Pending [ticket:2021]")
     def test_server_default_no_autoincrement(self):
         self._run_test(server_default='1', autoincrement=False)
 
@@ -767,4 +774,128 @@ class SpecialTypePKTest(testing.TestBase):
     def test_server_default_no_implicit_returning(self):
         self._run_test(server_default='1', autoincrement=False)
 
+class ServerDefaultsOnPKTest(testing.TestBase):
+    @testing.provide_metadata
+    def test_string_default_none_on_insert(self):
+        """Test that without implicit returning, we return None for 
+        a string server default.  
+        
+        That is, we don't want to attempt to pre-execute "server_default"
+        generically - the user should use a Python side-default for a case
+        like this.   Testing that all backends do the same thing here.
+        
+        """
+        t = Table('x', metadata, 
+                Column('y', String(10), server_default='key_one', primary_key=True),
+                Column('data', String(10)),
+                implicit_returning=False
+                )
+        metadata.create_all()
+        r = t.insert().execute(data='data')
+        eq_(r.inserted_primary_key, [None])
+        eq_(
+            t.select().execute().fetchall(),
+            [('key_one', 'data')]
+        )
+
+    @testing.requires.returning
+    @testing.provide_metadata
+    def test_string_default_on_insert_with_returning(self):
+        """With implicit_returning, we get a string PK default back no problem."""
+        t = Table('x', metadata, 
+                Column('y', String(10), server_default='key_one', primary_key=True),
+                Column('data', String(10))
+                )
+        metadata.create_all()
+        r = t.insert().execute(data='data')
+        eq_(r.inserted_primary_key, ['key_one'])
+        eq_(
+            t.select().execute().fetchall(),
+            [('key_one', 'data')]
+        )
+
+    @testing.provide_metadata
+    def test_int_default_none_on_insert(self):
+        t = Table('x', metadata, 
+                Column('y', Integer, 
+                        server_default='5', primary_key=True),
+                Column('data', String(10)),
+                implicit_returning=False
+                )
+        assert t._autoincrement_column is None
+        metadata.create_all()
+        r = t.insert().execute(data='data')
+        eq_(r.inserted_primary_key, [None])
+        if testing.against('sqlite'):
+            eq_(
+                t.select().execute().fetchall(),
+                [(1, 'data')]
+            )
+        else:
+            eq_(
+                t.select().execute().fetchall(),
+                [(5, 'data')]
+            )
+    @testing.fails_on('firebird', "col comes back as autoincrement")
+    @testing.fails_on('sqlite', "col comes back as autoincrement")
+    @testing.fails_on('oracle', "col comes back as autoincrement")
+    @testing.provide_metadata
+    def test_autoincrement_reflected_from_server_default(self):
+        t = Table('x', metadata, 
+                Column('y', Integer, 
+                        server_default='5', primary_key=True),
+                Column('data', String(10)),
+                implicit_returning=False
+                )
+        assert t._autoincrement_column is None
+        metadata.create_all()
+
+        m2 = MetaData(metadata.bind)
+        t2 = Table('x', m2, autoload=True, implicit_returning=False)
+        assert t2._autoincrement_column is None
+
+    @testing.fails_on('firebird', "attempts to insert None")
+    @testing.fails_on('sqlite', "returns a value")
+    @testing.provide_metadata
+    def test_int_default_none_on_insert_reflected(self):
+        t = Table('x', metadata, 
+                Column('y', Integer, 
+                        server_default='5', primary_key=True),
+                Column('data', String(10)),
+                implicit_returning=False
+                )
+        metadata.create_all()
+
+        m2 = MetaData(metadata.bind)
+        t2 = Table('x', m2, autoload=True, implicit_returning=False)
+
+        r = t2.insert().execute(data='data')
+        eq_(r.inserted_primary_key, [None])
+        if testing.against('sqlite'):
+            eq_(
+                t2.select().execute().fetchall(),
+                [(1, 'data')]
+            )
+        else:
+            eq_(
+                t2.select().execute().fetchall(),
+                [(5, 'data')]
+            )
+
+    @testing.requires.returning
+    @testing.provide_metadata
+    def test_int_default_on_insert_with_returning(self):
+        t = Table('x', metadata, 
+                Column('y', Integer, 
+                        server_default='5', primary_key=True),
+                Column('data', String(10))
+                )
+
+        metadata.create_all()
+        r = t.insert().execute(data='data')
+        eq_(r.inserted_primary_key, [5])
+        eq_(
+            t.select().execute().fetchall(),
+            [(5, 'data')]
+        )