]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added support for dialects that have both sequences and autoincrementing PKs.
authorJason Kirtland <jek@discorporate.us>
Tue, 23 Oct 2007 01:47:21 +0000 (01:47 +0000)
committerJason Kirtland <jek@discorporate.us>
Tue, 23 Oct 2007 01:47:21 +0000 (01:47 +0000)
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/sql/defaults.py

index 55eb9a45be3c5edcc1a38ffc457a5378f3cb43d1..5316f528dc7a03cdc6982f5f87bae5b54c282246 100644 (file)
@@ -142,7 +142,8 @@ class FBDialect(default.DefaultDialect):
     supports_sane_rowcount = False
     supports_sane_multi_rowcount = False
     max_identifier_length = 31
-    preexecute_sequences = True
+    preexecute_pk_sequences = True
+    supports_pk_autoincrement = False
 
     def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
index 2c5eacdbdeb1249ed3287bfd7a76a41a55a78a41..f4f8aa689f9e47ccef90fa931e7228e62b456761 100644 (file)
@@ -237,7 +237,8 @@ class OracleDialect(default.DefaultDialect):
     max_identifier_length = 30
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = False
-    preexecute_sequences = True
+    preexecute_pk_sequences = True
+    supports_pk_autoincrement = False
 
     def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, **kwargs):
         default.DefaultDialect.__init__(self, default_paramstyle='named', **kwargs)
index ddf6a6b9ce8aaf5105f57adbf668d47e6ac8ff98..018074b67755a3361d7a6d6768f457266c87831f 100644 (file)
@@ -279,7 +279,8 @@ class PGDialect(default.DefaultDialect):
     max_identifier_length = 63
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = False
-    preexecute_sequences = True
+    preexecute_pk_sequences = True
+    supports_pk_autoincrement = False
 
     def __init__(self, use_oids=False, server_side_cursors=False, **kwargs):
         default.DefaultDialect.__init__(self, default_paramstyle='pyformat', **kwargs)
index 5f39756849f8f6586160a068b06352aa0292b137..131f50540443dd722351cddf300ab5c5e03b1dfb 100644 (file)
@@ -12,9 +12,9 @@ higher-level statement-construction, connection-management, execution
 and result contexts.
 """
 
+import StringIO, sys
 from sqlalchemy import exceptions, schema, util, types, logging
 from sqlalchemy.sql import expression, visitors
-import StringIO, sys
 
 
 class Dialect(object):
@@ -79,9 +79,14 @@ class Dialect(object):
       Indicate whether the dialect properly implements rowcount for ``UPDATE`` and ``DELETE`` statements
       when executed via executemany.
 
-    preexecute_sequences
-      Indicate if the dialect should pre-execute sequences on primary key columns during an INSERT,
-      if it's desired that the new row's primary key be available after execution.
+    preexecute_pk_sequences
+      Indicate if the dialect should pre-execute sequences on primary key
+      columns during an INSERT, if it's desired that the new row's primary key
+      be available after execution.
+
+    supports_pk_autoincrement
+      Indicates if the dialect should allow the database to passively assign
+      a primary key column value.
     """
 
     def create_connect_args(self, url):
index 1a15c8b8d1754008cfba1d4105aa1f43422b662c..d826b97fadabf9ec0a5fbf9711fb07203628a68d 100644 (file)
@@ -31,7 +31,8 @@ class DefaultDialect(base.Dialect):
     max_identifier_length = 9999
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
-    preexecute_sequences = False
+    preexecute_pk_sequences = False
+    supports_pk_autoincrement = True
 
     def __init__(self, convert_unicode=False, encoding='utf-8', default_paramstyle='named', paramstyle=None, dbapi=None, **kwargs):
         self.convert_unicode = convert_unicode
@@ -47,7 +48,17 @@ class DefaultDialect(base.Dialect):
             self.paramstyle = default_paramstyle
         self.positional = self.paramstyle in ('qmark', 'format', 'numeric')
         self.identifier_preparer = self.preparer(self)
-    
+
+        # preexecute_sequences was renamed preexecute_pk_sequences.  If a
+        # subclass has the older property, proxy the new name to the subclass's
+        # property.
+        # TODO: remove @ 0.5.0
+        if (hasattr(self, 'preexecute_sequences') and
+            isinstance(getattr(type(self), 'preexecute_pk_sequences'), bool)):
+            setattr(type(self), 'preexecute_pk_sequences',
+                    property(lambda s: s.preexecute_sequences, doc=(
+                      "Proxy to deprecated preexecute_sequences attribute.")))
+
     def dbapi_type_map(self):
         # most DB-APIs have problems with this (such as, psycocpg2 types 
         # are unhashable).  So far Oracle can return it.
index 5572c2ed4eff29687aa3e3ddefc246a3779c36a9..f2627eb85f0defd010a956bb9d10a79aed3cf0ae 100644 (file)
@@ -674,9 +674,15 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
                 values.append((c, value))
             elif isinstance(c, schema.Column):
                 if self.isinsert:
-                    if c.primary_key and self.dialect.preexecute_sequences and not self.inline:
-                        values.append((c, create_bind_param(c, None)))
-                        self.prefetch.add(c)
+                    if (c.primary_key and self.dialect.preexecute_pk_sequences
+                        and not self.inline):
+                        if (((isinstance(c.default, schema.Sequence) and
+                              not c.default.optional) or
+                             not self.dialect.supports_pk_autoincrement) or
+                            (c.default is not None and 
+                             not isinstance(c.default, schema.Sequence))):
+                            values.append((c, create_bind_param(c, None)))
+                            self.prefetch.add(c)
                     elif isinstance(c.default, schema.ColumnDefault):
                         if isinstance(c.default.arg, sql.ClauseElement):
                             values.append((c, self.process(c.default.arg.self_group())))
index a34dc303e5b886a6b6f62ecda239d50db4e7a4fd..c29ffa3b396d416ffef78be702928afac292e0ad 100644 (file)
@@ -1,9 +1,9 @@
 import testbase
+import datetime
 from sqlalchemy import *
 from sqlalchemy import exceptions, schema, util
 from sqlalchemy.orm import mapper, create_session
 from testlib import *
-import datetime
 
 class DefaultTest(PersistTest):
 
@@ -31,12 +31,13 @@ class DefaultTest(PersistTest):
                 # since its a "branched" connection
                 conn.close()
             
-        use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle'
-        is_oracle = db.engine.name == 'oracle'
+        use_function_defaults = testing.against('postgres', 'oracle')
+        is_oracle = testing.against('oracle')
  
         # select "count(1)" returns different results on different DBs
         # also correct for "current_date" compatible as column default, value differences
-        currenttime = func.current_date(type_=Date, bind=db);
+        currenttime = func.current_date(type_=Date, bind=db)
+
         if is_oracle:
             ts = db.func.trunc(func.sysdate(), literal_column("'DAY'")).scalar()
             f = select([func.length('abcdef')], bind=db).scalar()
@@ -50,7 +51,10 @@ class DefaultTest(PersistTest):
             f = select([func.length('abcdef')], bind=db).scalar()
             f2 = select([func.length('abcdefghijk')], bind=db).scalar()
             def1 = currenttime
-            def2 = text("current_date")
+            if testing.against('maxdb'):
+                def2 = text("curdate")
+            else:
+                def2 = text("current_date")
             deftype = Date
             ts = db.func.current_date().scalar()
         else:
@@ -153,7 +157,7 @@ class DefaultTest(PersistTest):
 
     def testinsertmany(self):
         # MySQL-Python 1.2.2 breaks functions in execute_many :(
-        if (testbase.db.name == 'mysql' and
+        if (testing.against('mysql') and
             testbase.db.dialect.dbapi.version_info[:3] == (1, 2, 2)):
             return
 
@@ -171,7 +175,7 @@ class DefaultTest(PersistTest):
         
     def testupdatemany(self):
         # MySQL-Python 1.2.2 breaks functions in execute_many :(
-        if (testbase.db.name == 'mysql' and
+        if (testing.against('mysql') and
             testbase.db.dialect.dbapi.version_info[:3] == (1, 2, 2)):
             return
 
@@ -254,7 +258,7 @@ class AutoIncrementTest(PersistTest):
     def tearDown(self):
         aimeta.drop_all()
 
-    @testing.supported('postgres', 'mysql')
+    @testing.supported('postgres', 'mysql', 'maxdb')
     def testnonautoincrement(self):
         meta = MetaData(testbase.db)
         nonai_table = Table("nonaitest", meta, 
@@ -326,9 +330,29 @@ class AutoIncrementTest(PersistTest):
         finally:
             con.close()
 
+    def test_autoincrement_fk(self):
+        if not testbase.db.dialect.supports_pk_autoincrement:
+            return True
+        
+        metadata = MetaData(testbase.db)
+
+        # No optional sequence here.
+        nodes = Table('nodes', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('parent_id', Integer, ForeignKey('nodes.id')),
+            Column('data', String(30)))
+        metadata.create_all()
+        try:
+            r = nodes.insert().execute(data='foo')
+            id_ = r.last_inserted_ids()[0]
+            nodes.insert().execute(data='bar', parent_id=id_)
+        finally:
+            metadata.drop_all()
+
+
 
 class SequenceTest(PersistTest):
-    @testing.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle', 'maxdb')
     def setUpAll(self):
         global cartitems, sometable, metadata
         metadata = MetaData(testbase.db)
@@ -338,16 +362,17 @@ class SequenceTest(PersistTest):
             Column("createdate", DateTime())
         )
         sometable = Table( 'Manager', metadata,
-               Column( 'obj_id', Integer, Sequence('obj_id_seq'), ),
-               Column( 'name', String, ),
-               Column( 'id', Integer, Sequence('Manager_id_seq', optional=True), primary_key=True),
+               Column('obj_id', Integer, Sequence('obj_id_seq'), ),
+               Column('name', String, ),
+               Column('id', Integer, Sequence('Manager_id_seq', optional=True), primary_key=True),
            )
         
         metadata.create_all()
     
-    @testing.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle', 'maxdb')
     def testseqnonpk(self):
         """test sequences fire off as defaults on non-pk columns"""
+
         sometable.insert().execute(name="somename")
         sometable.insert().execute(name="someother")
         sometable.insert().execute(
@@ -360,17 +385,26 @@ class SequenceTest(PersistTest):
             (3, "name3", 3),
             (4, "name4", 4),
         ]
-        
-    @testing.supported('postgres', 'oracle')
+
+    @testing.supported('postgres', 'oracle', 'maxdb')
     def testsequence(self):
         cartitems.insert().execute(description='hi')
         cartitems.insert().execute(description='there')
-        cartitems.insert().execute(description='lala')
+        r = cartitems.insert().execute(description='lala')
+
+        assert r.last_inserted_ids() and r.last_inserted_ids()[0] is not None
+        id_ = r.last_inserted_ids()[0]
+
+        assert select([func.count(cartitems.c.cart_id)],
+                      and_(cartitems.c.description == 'lala',
+                           cartitems.c.cart_id == id_)).scalar() == 1
         
         cartitems.select().execute().fetchall()
-   
+
    
     @testing.supported('postgres', 'oracle')
+    # maxdb db-api seems to double-execute NEXTVAL internally somewhere,
+    # throwing off the numbers for these tests...
     def test_implicit_sequence_exec(self):
         s = Sequence("my_sequence", metadata=MetaData(testbase.db))
         s.create()
@@ -390,7 +424,7 @@ class SequenceTest(PersistTest):
         finally:
             s.drop(testbase.db)
     
-    @testing.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle', 'maxdb')
     def test_checkfirst(self):
         s = Sequence("my_sequence")
         s.create(testbase.db, checkfirst=False)
@@ -403,7 +437,7 @@ class SequenceTest(PersistTest):
         x = cartitems.c.cart_id.sequence.execute()
         self.assert_(1 <= x <= 4)
         
-    @testing.supported('postgres', 'oracle')
+    @testing.supported('postgres', 'oracle', 'maxdb')
     def tearDownAll(self): 
         metadata.drop_all()