def schemadropper(self, **params):
return PGSchemaDropper(self, **params)
- def defaultrunner(self, proxy):
+ def defaultrunner(self, proxy=None):
return PGDefaultRunner(self, proxy)
def get_default_schema_name(self):
self.execute()
class PGDefaultRunner(ansisql.ANSIDefaultRunner):
- def get_column_default(self, column):
+ def get_column_default(self, column, isinsert=True):
if column.primary_key:
# passive defaults on primary keys have to be overridden
if isinstance(column.default, schema.PassiveDefault):
"""
raise NotImplementedError()
- def defaultrunner(self, proxy):
+ def defaultrunner(self, proxy=None):
"""Returns a schema.SchemaVisitor instance that can execute the default values on a column.
The base class for this visitor is the DefaultRunner class inside this module.
This visitor will typically only receive schema.DefaultGenerator schema objects. The given
defaultrunner is called within the context of the execute_compiled() method."""
return DefaultRunner(self, proxy)
-
+
def compiler(self, statement, parameters):
"""returns a sql.ClauseVisitor which will produce a string representation of the given
ClauseElement and parameter dictionary. This object is usually a subclass of
self.post_exec(proxy, compiled, parameters, **kwargs)
return ResultProxy(cursor, self, typemap=compiled.typemap)
- def execute(self, statement, parameters, connection=None, cursor=None, echo=None, typemap=None, commit=False, return_raw=False, **kwargs):
+ def execute(self, statement, parameters=None, connection=None, cursor=None, echo=None, typemap=None, commit=False, return_raw=False, **kwargs):
"""executes the given string-based SQL statement with the given parameters.
The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending
self.parent.table.foreign_keys.append(self)
class DefaultGenerator(SchemaItem):
- """Base class for column "default" values, which can be a plain default
- or a Sequence."""
+ """Base class for column "default" values."""
def _set_parent(self, column):
self.column = column
self.column.default = self
+ def execute(self):
+ return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.execute))
def __repr__(self):
return "DefaultGenerator()"
class Sequence(DefaultGenerator):
"""represents a sequence, which applies to Oracle and Postgres databases."""
- def __init__(self, name, start = None, increment = None, optional=False):
+ def __init__(self, name, start = None, increment = None, optional=False, engine=None):
self.name = name
self.start = start
self.increment = increment
self.optional=optional
+ self.engine = engine
def __repr__(self):
return "Sequence(%s)" % string.join(
[repr(self.name)] +
["%s=%s" % (k, repr(getattr(self, k))) for k in ['start', 'increment', 'optional']]
, ',')
-
+ def _set_parent(self, column):
+ super(Sequence, self)._set_parent(column)
+ column.sequence = self
+ if self.engine is None:
+ self.engine = column.table.engine
+ def create(self):
+ self.engine.create(self)
+ return self
+ def drop(self):
+ self.engine.drop(self)
def accept_schema_visitor(self, visitor):
"""calls the visit_seauence method on the given visitor."""
return visitor.visit_sequence(self)
import sqlalchemy.databases.sqlite as sqllite
db = testbase.db
-db.echo='debug'
+#db.echo='debug'
from sqlalchemy import *
from sqlalchemy.engine import ResultProxy, RowProxy
finally:
test_table.drop()
- def testdefaults(self):
- x = {'x':50}
- def mydefault():
- x['x'] += 1
- return x['x']
-
- use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle'
- is_oracle = db.engine.name == 'oracle'
-
- # select "count(1)" from the DB which returns different results
- # on different DBs
- if is_oracle:
- f = select([func.count(1) + 5], engine=db, from_obj=['DUAL']).scalar()
- ts = select([func.sysdate()], engine=db, from_obj=['DUAL']).scalar()
- def1 = func.sysdate()
- def2 = text("sysdate")
- deftype = Date
- elif use_function_defaults:
- f = select([func.count(1) + 5], engine=db).scalar()
- def1 = func.current_date()
- def2 = text("current_date")
- deftype = Date
- ts = select([func.current_date()], engine=db).scalar()
- else:
- f = select([func.count(1) + 5], engine=db).scalar()
- def1 = def2 = "3"
- ts = 3
- deftype = Integer
-
- t = Table('default_test1', db,
- # python function
- Column('col1', Integer, primary_key=True, default=mydefault),
-
- # python literal
- Column('col2', String(20), default="imthedefault"),
-
- # preexecute expression
- Column('col3', Integer, default=func.count(1) + 5),
-
- # SQL-side default from sql expression
- Column('col4', deftype, PassiveDefault(def1)),
-
- # SQL-side default from literal expression
- Column('col5', deftype, PassiveDefault(def2))
- )
- t.create()
- try:
- t.insert().execute()
- self.assert_(t.engine.lastrow_has_defaults())
- t.insert().execute()
- t.insert().execute()
-
- l = t.select().execute()
- self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
- finally:
- t.drop()
def testdelete(self):
c = db.connection()
from sqlalchemy import *
import sqlalchemy
+db = testbase.db
-class SequenceTest(PersistTest):
+class DefaultTest(PersistTest):
+
+ def testdefaults(self):
+ x = {'x':50}
+ def mydefault():
+ x['x'] += 1
+ return x['x']
+
+ use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle'
+ is_oracle = db.engine.name == 'oracle'
+
+ # select "count(1)" from the DB which returns different results
+ # on different DBs
+ if is_oracle:
+ f = select([func.count(1) + 5], engine=db, from_obj=['DUAL']).scalar()
+ ts = select([func.sysdate()], engine=db, from_obj=['DUAL']).scalar()
+ def1 = func.sysdate()
+ def2 = text("sysdate")
+ deftype = Date
+ elif use_function_defaults:
+ f = select([func.count(1) + 5], engine=db).scalar()
+ def1 = func.current_date()
+ def2 = text("current_date")
+ deftype = Date
+ ts = select([func.current_date()], engine=db).scalar()
+ else:
+ f = select([func.count(1) + 5], engine=db).scalar()
+ def1 = def2 = "3"
+ ts = 3
+ deftype = Integer
+
+ t = Table('default_test1', db,
+ # python function
+ Column('col1', Integer, primary_key=True, default=mydefault),
+
+ # python literal
+ Column('col2', String(20), default="imthedefault"),
+
+ # preexecute expression
+ Column('col3', Integer, default=func.count(1) + 5),
+
+ # SQL-side default from sql expression
+ Column('col4', deftype, PassiveDefault(def1)),
+
+ # SQL-side default from literal expression
+ Column('col5', deftype, PassiveDefault(def2))
+ )
+ t.create()
+ try:
+ t.insert().execute()
+ self.assert_(t.engine.lastrow_has_defaults())
+ t.insert().execute()
+ t.insert().execute()
+
+ l = t.select().execute()
+ self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)])
+ finally:
+ t.drop()
- def setUp(self):
- db = sqlalchemy.engine.create_engine('postgres', {'database':'test', 'host':'127.0.0.1', 'user':'scott', 'password':'tiger'}, echo=testbase.echo)
- #db = sqlalchemy.engine.create_engine('oracle', {'dsn':os.environ['DSN'], 'user':os.environ['USER'], 'password':os.environ['PASSWORD']}, echo=testbase.echo)
+class SequenceTest(PersistTest):
- self.table = Table("cartitems", db,
+ def setUpAll(self):
+ if testbase.db.engine.name != 'postgres' and testbase.db.engine.name != 'oracle':
+ return
+ global cartitems
+ cartitems = Table("cartitems", db,
Column("cart_id", Integer, Sequence('cart_id_seq'), primary_key=True),
Column("description", String(40)),
Column("createdate", DateTime())
)
- self.table.create()
+ cartitems.create()
def testsequence(self):
- self.table.insert().execute(description='hi')
- self.table.insert().execute(description='there')
- self.table.insert().execute(description='lala')
+ cartitems.insert().execute(description='hi')
+ cartitems.insert().execute(description='there')
+ cartitems.insert().execute(description='lala')
- self.table.select().execute().fetchall()
+ cartitems.select().execute().fetchall()
- def tearDown(self):
- self.table.drop()
+
+ def teststandalone(self):
+ s = Sequence("my_sequence", engine=db)
+ s.create()
+ try:
+ x =s.execute()
+ self.assert_(x == 1)
+ finally:
+ s.drop()
+
+ def teststandalone2(self):
+ x = cartitems.c.cart_id.sequence.execute()
+ self.assert_(1 <= x <= 4)
+
+ def tearDownAll(self):
+ if testbase.db.engine.name != 'postgres' and testbase.db.engine.name != 'oracle':
+ return
+ cartitems.drop()
if __name__ == "__main__":
unittest.main()