]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
making sequences, column defaults independently executeable
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Mar 2006 19:26:23 +0000 (19:26 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Mar 2006 19:26:23 +0000 (19:26 +0000)
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/schema.py
test/query.py
test/sequence.py

index 592bac79c8b4af46bd25f5e18ddf182a9fd37da2..105fe7a76fe0410052c5cb32acbe2f476397d6a7 100644 (file)
@@ -218,7 +218,7 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
     def schemadropper(self, **params):
         return PGSchemaDropper(self, **params)
 
-    def defaultrunner(self, proxy):
+    def defaultrunner(self, proxy=None):
         return PGDefaultRunner(self, proxy)
         
     def get_default_schema_name(self):
@@ -346,7 +346,7 @@ class PGSchemaDropper(ansisql.ANSISchemaDropper):
             self.execute()
 
 class PGDefaultRunner(ansisql.ANSIDefaultRunner):
-    def get_column_default(self, column):
+    def get_column_default(self, column, isinsert=True):
         if column.primary_key:
             # passive defaults on primary keys have to be overridden
             if isinstance(column.default, schema.PassiveDefault):
index d07dd57341e18aedf9ce5ef88acb87179654fdd9..0f6b659093aab8de3d155fc5493db11867d2bb05 100644 (file)
@@ -265,7 +265,7 @@ class SQLEngine(schema.SchemaEngine):
         """
         raise NotImplementedError()
 
-    def defaultrunner(self, proxy):
+    def defaultrunner(self, proxy=None):
         """Returns a schema.SchemaVisitor instance that can execute the default values on a column.
         The base class for this visitor is the DefaultRunner class inside this module.
         This visitor will typically only receive schema.DefaultGenerator schema objects.  The given 
@@ -275,7 +275,7 @@ class SQLEngine(schema.SchemaEngine):
         
         defaultrunner is called within the context of the execute_compiled() method."""
         return DefaultRunner(self, proxy)
-        
+    
     def compiler(self, statement, parameters):
         """returns a sql.ClauseVisitor which will produce a string representation of the given
         ClauseElement and parameter dictionary.  This object is usually a subclass of 
@@ -529,7 +529,7 @@ class SQLEngine(schema.SchemaEngine):
         self.post_exec(proxy, compiled, parameters, **kwargs)
         return ResultProxy(cursor, self, typemap=compiled.typemap)
 
-    def execute(self, statement, parameters, connection=None, cursor=None, echo=None, typemap=None, commit=False, return_raw=False, **kwargs):
+    def execute(self, statement, parameters=None, connection=None, cursor=None, echo=None, typemap=None, commit=False, return_raw=False, **kwargs):
         """executes the given string-based SQL statement with the given parameters.  
 
         The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending
index a11a1539e82287d87e588dabf7f935d0bb1f0899..17e421f22805d7fe49d0a919f600e39e5e0058d0 100644 (file)
@@ -434,11 +434,12 @@ class ForeignKey(SchemaItem):
         self.parent.table.foreign_keys.append(self)
 
 class DefaultGenerator(SchemaItem):
-    """Base class for column "default" values, which can be a plain default
-    or a Sequence."""
+    """Base class for column "default" values."""
     def _set_parent(self, column):
         self.column = column
         self.column.default = self
+    def execute(self):
+        return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.execute))
     def __repr__(self):
         return "DefaultGenerator()"
 
@@ -464,17 +465,27 @@ class ColumnDefault(DefaultGenerator):
         
 class Sequence(DefaultGenerator):
     """represents a sequence, which applies to Oracle and Postgres databases."""
-    def __init__(self, name, start = None, increment = None, optional=False):
+    def __init__(self, name, start = None, increment = None, optional=False, engine=None):
         self.name = name
         self.start = start
         self.increment = increment
         self.optional=optional
+        self.engine = engine
     def __repr__(self):
         return "Sequence(%s)" % string.join(
              [repr(self.name)] +
              ["%s=%s" % (k, repr(getattr(self, k))) for k in ['start', 'increment', 'optional']]
             , ',')
-    
+    def _set_parent(self, column):
+        super(Sequence, self)._set_parent(column)
+        column.sequence = self
+        if self.engine is None:
+            self.engine = column.table.engine
+    def create(self):
+       self.engine.create(self)
+       return self
+    def drop(self):
+       self.engine.drop(self)
     def accept_schema_visitor(self, visitor):
         """calls the visit_seauence method on the given visitor."""
         return visitor.visit_sequence(self)
index cf0bc94d32c6011b31658ebe904963f2eb06abc8..a6e1bb4191d52311b9ab329548eb965670913254 100644 (file)
@@ -5,7 +5,7 @@ import unittest, sys, datetime
 import sqlalchemy.databases.sqlite as sqllite
 
 db = testbase.db
-db.echo='debug'
+#db.echo='debug'
 from sqlalchemy import *
 from sqlalchemy.engine import ResultProxy, RowProxy
 
@@ -90,62 +90,6 @@ class QueryTest(PersistTest):
         finally:
             test_table.drop()
 
-    def testdefaults(self):
-        x = {'x':50}
-        def mydefault():
-            x['x'] += 1
-            return x['x']
-
-        use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle'
-        is_oracle = db.engine.name == 'oracle'
-        # select "count(1)" from the DB which returns different results
-        # on different DBs
-        if is_oracle:
-            f = select([func.count(1) + 5], engine=db, from_obj=['DUAL']).scalar()
-            ts = select([func.sysdate()], engine=db, from_obj=['DUAL']).scalar()
-            def1 = func.sysdate()
-            def2 = text("sysdate")
-            deftype = Date
-        elif use_function_defaults:
-            f = select([func.count(1) + 5], engine=db).scalar()
-            def1 = func.current_date()
-            def2 = text("current_date")
-            deftype = Date
-            ts = select([func.current_date()], engine=db).scalar()
-        else:
-            f = select([func.count(1) + 5], engine=db).scalar()
-            def1 = def2 = "3"
-            ts = 3
-            deftype = Integer
-            
-        t = Table('default_test1', db,
-            # python function
-            Column('col1', Integer, primary_key=True, default=mydefault),
-            
-            # python literal
-            Column('col2', String(20), default="imthedefault"),
-            
-            # preexecute expression
-            Column('col3', Integer, default=func.count(1) + 5),
-            
-            # SQL-side default from sql expression
-            Column('col4', deftype, PassiveDefault(def1)),
-            
-            # SQL-side default from literal expression
-            Column('col5', deftype, PassiveDefault(def2))
-        )
-        t.create()
-        try:
-            t.insert().execute()
-            self.assert_(t.engine.lastrow_has_defaults())
-            t.insert().execute()
-            t.insert().execute()
-        
-            l = t.select().execute()
-            self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
-        finally:
-            t.drop()
         
     def testdelete(self):
         c = db.connection()
index 4d4390d18b57a19d500df395054cac3ed1dcaa5e..fcf852a86a77b3a0a0a8b2b5743c136579b86ca1 100644 (file)
@@ -6,30 +6,106 @@ import testbase
 from sqlalchemy import *
 import sqlalchemy
 
+db = testbase.db
 
-class SequenceTest(PersistTest):
+class DefaultTest(PersistTest):
+
+    def testdefaults(self):
+        x = {'x':50}
+        def mydefault():
+            x['x'] += 1
+            return x['x']
+
+        use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle'
+        is_oracle = db.engine.name == 'oracle'
+        # select "count(1)" from the DB which returns different results
+        # on different DBs
+        if is_oracle:
+            f = select([func.count(1) + 5], engine=db, from_obj=['DUAL']).scalar()
+            ts = select([func.sysdate()], engine=db, from_obj=['DUAL']).scalar()
+            def1 = func.sysdate()
+            def2 = text("sysdate")
+            deftype = Date
+        elif use_function_defaults:
+            f = select([func.count(1) + 5], engine=db).scalar()
+            def1 = func.current_date()
+            def2 = text("current_date")
+            deftype = Date
+            ts = select([func.current_date()], engine=db).scalar()
+        else:
+            f = select([func.count(1) + 5], engine=db).scalar()
+            def1 = def2 = "3"
+            ts = 3
+            deftype = Integer
+            
+        t = Table('default_test1', db,
+            # python function
+            Column('col1', Integer, primary_key=True, default=mydefault),
+            
+            # python literal
+            Column('col2', String(20), default="imthedefault"),
+            
+            # preexecute expression
+            Column('col3', Integer, default=func.count(1) + 5),
+            
+            # SQL-side default from sql expression
+            Column('col4', deftype, PassiveDefault(def1)),
+            
+            # SQL-side default from literal expression
+            Column('col5', deftype, PassiveDefault(def2))
+        )
+        t.create()
+        try:
+            t.insert().execute()
+            self.assert_(t.engine.lastrow_has_defaults())
+            t.insert().execute()
+            t.insert().execute()
+        
+            l = t.select().execute()
+            self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
+        finally:
+            t.drop()
 
-    def setUp(self):
-        db = sqlalchemy.engine.create_engine('postgres', {'database':'test', 'host':'127.0.0.1', 'user':'scott', 'password':'tiger'}, echo=testbase.echo)
-        #db = sqlalchemy.engine.create_engine('oracle', {'dsn':os.environ['DSN'], 'user':os.environ['USER'], 'password':os.environ['PASSWORD']}, echo=testbase.echo)
+class SequenceTest(PersistTest):
 
-        self.table = Table("cartitems", db, 
+    def setUpAll(self):
+        if testbase.db.engine.name != 'postgres' and testbase.db.engine.name != 'oracle':
+            return
+        global cartitems
+        cartitems = Table("cartitems", db, 
             Column("cart_id", Integer, Sequence('cart_id_seq'), primary_key=True),
             Column("description", String(40)),
             Column("createdate", DateTime())
         )
         
-        self.table.create()
+        cartitems.create()
 
     def testsequence(self):
-        self.table.insert().execute(description='hi')
-        self.table.insert().execute(description='there')
-        self.table.insert().execute(description='lala')
+        cartitems.insert().execute(description='hi')
+        cartitems.insert().execute(description='there')
+        cartitems.insert().execute(description='lala')
         
-        self.table.select().execute().fetchall()
+        cartitems.select().execute().fetchall()
    
-    def tearDown(self): 
-       self.table.drop()
+   
+    def teststandalone(self):
+        s = Sequence("my_sequence", engine=db)
+        s.create()
+        try:
+            x =s.execute()
+            self.assert_(x == 1)
+        finally:
+            s.drop()
+    
+    def teststandalone2(self):
+        x = cartitems.c.cart_id.sequence.execute()
+        self.assert_(1 <= x <= 4)
+        
+    def tearDownAll(self): 
+        if testbase.db.engine.name != 'postgres' and testbase.db.engine.name != 'oracle':
+            return
+        cartitems.drop()
 
 if __name__ == "__main__":
     unittest.main()