]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
postgres oids say byebye by default, putting hooks in for engines to determine column...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 1 Jan 2006 20:30:53 +0000 (20:30 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 1 Jan 2006 20:30:53 +0000 (20:30 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapping/mapper.py
test/objectstore.py
test/testbase.py

index a06abe598740e15ef1e911163ac0c97aa4ba45c5..bcd20349b4d60c7405203a2980934de70a38060c 100644 (file)
@@ -370,18 +370,25 @@ class ANSICompiler(sql.Compiled):
         """called when visiting an Insert statement, for each column in the table that
         contains a Sequence object."""
         pass
+    
+    def visit_insert_column(selef, column):
+        """called when visiting an Insert statement, for each column in the table
+        that is a NULL insert into the table"""
+        pass
         
     def visit_insert(self, insert_stmt):
         # set up a call for the defaults and sequences inside the table
         class DefaultVisitor(schema.SchemaVisitor):
+            def visit_column(s, c):
+                self.visit_insert_column(c)
             def visit_column_default(s, cd):
                 self.visit_insert_column_default(c, cd)
             def visit_sequence(s, seq):
                 self.visit_insert_sequence(c, seq)
         vis = DefaultVisitor()
         for c in insert_stmt.table.c:
-            if (self.parameters is None or self.parameters.get(c.key, None) is None) and c.default is not None:
-                c.default.accept_visitor(vis)
+            if (self.parameters is None or self.parameters.get(c.key, None) is None):
+                c.accept_visitor(vis)
         
         self.isinsert = True
         colparams = self._get_colparams(insert_stmt)
index 0e9328a20299e8d22945efbbc3c06ca5ce731104..3ae83cb8c4001e93b9075a5e54d8a7a178ec9307 100644 (file)
@@ -158,6 +158,7 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
 
     def post_exec(self, proxy, compiled, parameters, **kwargs):
         if getattr(compiled, "isinsert", False) and self.context.last_inserted_ids is None:
+            raise "cant use cursor.lastrowid without OIDs enabled"
             table = compiled.statement.table
             cursor = proxy()
             if cursor.lastrowid is not None and table is not None and len(table.primary_key):
@@ -190,8 +191,12 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
 
 class PGCompiler(ansisql.ANSICompiler):
 
-    def visit_insert_sequence(self, column, sequence):
-        if self.parameters.get(column.key, None) is None and not sequence.optional:
+    def visit_insert_column(self, column):
+        # Postgres advises against OID usage and turns it off in 8.1,
+        # effectively making cursor.lastrowid
+        # useless, effectively making reliance upon SERIAL useless.  
+        # so all column primary key inserts must be explicitly present
+        if column.primary_key:
             self.parameters[column.key] = None
 
     def limit_clause(self, select):
@@ -232,6 +237,13 @@ class PGSchemaDropper(ansisql.ANSISchemaDropper):
             self.execute()
 
 class PGDefaultRunner(ansisql.ANSIDefaultRunner):
+    def get_column_default(self, column):
+        if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+            c = self.proxy("select nextval('%s_%s_seq')" % (column.table.name, column.name))
+            return c.fetchone()[0]
+        else:
+            return ansisql.ANSIDefaultRunner.get_column_default(self, column)
+    
     def visit_sequence(self, seq):
         if not seq.optional:
             c = self.proxy("select nextval('%s')" % seq.name)
index f237014247d0d387fa342f5155ee66f8c1e9d1ed..f3cc8efada45f2ae4fdc67e4a1ec9815c7a93520 100644 (file)
@@ -129,6 +129,12 @@ class DefaultRunner(schema.SchemaVisitor):
         self.proxy = proxy
         self.engine = engine
 
+    def get_column_default(self, column):
+        if column.default is not None:
+            return column.default.accept_visitor(self)
+        else:
+            return None
+
     def visit_sequence(self, seq):
         """sequences are not supported by default"""
         return None
@@ -425,11 +431,7 @@ class SQLEngine(schema.SchemaEngine):
                 need_lastrowid=False
                 for c in compiled.statement.table.c:
                     if not param.has_key(c.key) or param[c.key] is None:
-                        if c.default is not None:
-                            newid = c.default.accept_visitor(drunner)
-                        else:
-                            newid = None
-                            
+                        newid = drunner.get_column_default(c)
                         if newid is not None:
                             param[c.key] = newid
                             if c.primary_key:
@@ -481,6 +483,7 @@ class SQLEngine(schema.SchemaEngine):
                        post-processing on result-set values.
 
         commit      -  if True, will automatically commit the statement after completion. """
+        
         if parameters is None:
             parameters = {}
 
@@ -545,6 +548,7 @@ class SQLEngine(schema.SchemaEngine):
                        post-processing on result-set values.
 
         commit      -  if True, will automatically commit the statement after completion. """
+        
         if parameters is None:
             parameters = {}
 
index 27479ae9d62ccaa64b0258341935b7a7f3f6f957..f9b62b9f5a375e3fe78cf9b86e949886bd447f30 100644 (file)
@@ -466,7 +466,7 @@ class Mapper(object):
             # for this table, in the case that the user
             # specified custom primary key cols.
             for obj in objects:
-#                print "SAVE_OBJ we are " + hash_key(self) + " obj: " +  obj.__class__.__name__ + repr(id(obj))
+                #print "SAVE_OBJ we are " + hash_key(self) + " obj: " +  obj.__class__.__name__ + repr(id(obj))
                 params = {}
 
                 isinsert = not hasattr(obj, "_instance_key")
index a22539cc6fbb82bdab6f44bff879a2d9198c4450..afad2116b2479c957177d1b5eff314d4eb0a2cfd 100644 (file)
@@ -358,6 +358,21 @@ class SaveTest(AssertMixin):
                     lambda: [{'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id}]
                 ),
                 
+        ],
+        with_sequences=[
+                (
+                    "INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)",
+                    lambda:{'user_name': 'imnewlyadded', 'user_id':db.last_inserted_ids()[0]}
+                ),
+                (
+                    "UPDATE email_addresses SET email_address=:email_address WHERE email_addresses.address_id = :email_addresses_address_id",
+                    lambda: [{'email_address': 'imnew@foo.bar', 'email_addresses_address_id': objects[2].address_id}]
+                ),
+                (
+                    "UPDATE email_addresses SET user_id=:user_id WHERE email_addresses.address_id = :email_addresses_address_id",
+                    lambda: [{'user_id': objects[3].user.user_id, 'email_addresses_address_id': objects[3].address_id}]
+                ),
+                
         ])
         l = sql.select([users, addresses], sql.and_(users.c.user_id==addresses.c.address_id, addresses.c.address_id==a.address_id)).execute()
         self.echo( repr(l.fetchone().row))
@@ -429,6 +444,12 @@ class SaveTest(AssertMixin):
                     "INSERT INTO email_addresses (user_id, email_address) VALUES (:user_id, :email_address)",
                     {'email_address': 'hi', 'user_id': 7}
                     ),
+                ],
+                with_sequences=[
+                    (
+                    "INSERT INTO email_addresses (address_id, user_id, email_address) VALUES (:address_id, :user_id, :email_address)",
+                    lambda:{'email_address': 'hi', 'user_id': 7, 'address_id':db.last_inserted_ids()[0]}
+                    ),
                 ]
         )
 
@@ -577,7 +598,6 @@ class SaveTest(AssertMixin):
         l = m.select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name, keywords.c.name])
         self.assert_result(l, *data)
 
-        print "\n\n\nTESTTESTTEST"
         objects[4].item_name = 'item4updated'
         k = Keyword()
         k.name = 'yellow'
@@ -593,7 +613,21 @@ class SaveTest(AssertMixin):
             ("INSERT INTO itemkeywords (item_id, keyword_id) VALUES (:item_id, :keyword_id)",
             lambda: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}]
             )
-        ])
+        ],
+        
+        with_sequences = [
+            {
+                "UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id":
+                [{'item_name': 'item4updated', 'items_item_id': objects[4].item_id}]
+            ,
+                "INSERT INTO keywords (keyword_id, name) VALUES (:keyword_id, :name)":
+                lambda: {'name': 'yellow', 'keyword_id':db.last_inserted_ids()[0]}
+            },
+            ("INSERT INTO itemkeywords (item_id, keyword_id) VALUES (:item_id, :keyword_id)",
+            lambda: [{'item_id': objects[5].item_id, 'keyword_id': k.keyword_id}]
+            )
+        ]
+        )
 
         objects[2].keywords.append(k)
         dkid = objects[5].keywords[1].keyword_id
@@ -609,7 +643,6 @@ class SaveTest(AssertMixin):
                 )
         ])
         
-        print "NEXT TEST"
         objectstore.delete(objects[3])
         objectstore.commit()
         
@@ -765,7 +798,26 @@ class SaveTest2(AssertMixin):
                 "INSERT INTO email_addresses (rel_user_id, email_address) VALUES (:rel_user_id, :email_address)",
                 {'rel_user_id': 2, 'email_address': 'thesdf@asdf.com'}
                 )
-                ]
+                ],
+                
+                with_sequences = [
+                        (
+                            "INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)",
+                            lambda: {'user_name': 'thesub', 'user_id':db.last_inserted_ids()[0]}
+                        ),
+                        (
+                        "INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)",
+                            lambda: {'user_name': 'assdkfj', 'user_id':db.last_inserted_ids()[0]}
+                        ),
+                        (
+                        "INSERT INTO email_addresses (address_id, rel_user_id, email_address) VALUES (:address_id, :rel_user_id, :email_address)",
+                        lambda:{'rel_user_id': 1, 'email_address': 'bar@foo.com', 'address_id':db.last_inserted_ids()[0]}
+                        ),
+                        (
+                        "INSERT INTO email_addresses (address_id, rel_user_id, email_address) VALUES (:address_id, :rel_user_id, :email_address)",
+                        lambda:{'rel_user_id': 2, 'email_address': 'thesdf@asdf.com', 'address_id':db.last_inserted_ids()[0]}
+                        )
+                        ]
         )
 
 
index a5785f29922f5392fde2be47d7a57176ce0b1208..c1342a5fcb7ef60b85563828ae8c666a44a17efa 100644 (file)
@@ -70,8 +70,11 @@ class AssertMixin(PersistTest):
                     self.assert_row(value[0], getattr(rowobj, key), value[1])
             else:
                 self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value))
-    def assert_sql(self, db, callable_, list):
-        db.set_assert_list(self, list)
+    def assert_sql(self, db, callable_, list, with_sequences=None):
+        if with_sequences is not None and (db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle')):
+            db.set_assert_list(self, with_sequences)
+        else:
+            db.set_assert_list(self, list)
         try:
             callable_()
         finally:
@@ -87,8 +90,8 @@ class EngineAssert(object):
     """decorates a SQLEngine object to match the incoming queries against a set of assertions."""
     def __init__(self, engine):
         self.engine = engine
-        self.realexec = engine.pre_exec
-        engine.pre_exec = self.pre_exec
+        self.realexec = engine.post_exec
+        engine.post_exec = self.post_exec
         self.logger = engine.logger
         self.set_assert_list(None, None)
         self.sql_count = 0
@@ -102,7 +105,7 @@ class EngineAssert(object):
     def _set_echo(self, echo):
         self.engine.echo = echo
     echo = property(lambda s: s.engine.echo, _set_echo)
-    def pre_exec(self, proxy, compiled, parameters, **kwargs):
+    def post_exec(self, proxy, compiled, parameters, **kwargs):
         self.engine.logger = self.logger
         statement = str(compiled)
         statement = re.sub(r'\n', '', statement)
@@ -127,7 +130,6 @@ class EngineAssert(object):
                     if len(item) == 1:
                         self.assert_list.pop()
                     item = (statement, entry)
-                    print "OK ON", statement
                 except KeyError:
                     self.unittest.assert_(False, "Testing for one of the following queries: %s, received '%s'" % (repr([k for k in item.keys()]), statement))