From: Mike Bayer Date: Sat, 4 Mar 2006 19:26:23 +0000 (+0000) Subject: making sequences, column defaults independently executeable X-Git-Tag: rel_0_1_4~46 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7c0ff2178bb338f1792a6efb961effcde79eef8b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git making sequences, column defaults independently executeable --- diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 592bac79c8..105fe7a76f 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -218,7 +218,7 @@ class PGSQLEngine(ansisql.ANSISQLEngine): def schemadropper(self, **params): return PGSchemaDropper(self, **params) - def defaultrunner(self, proxy): + def defaultrunner(self, proxy=None): return PGDefaultRunner(self, proxy) def get_default_schema_name(self): @@ -346,7 +346,7 @@ class PGSchemaDropper(ansisql.ANSISchemaDropper): self.execute() class PGDefaultRunner(ansisql.ANSIDefaultRunner): - def get_column_default(self, column): + def get_column_default(self, column, isinsert=True): if column.primary_key: # passive defaults on primary keys have to be overridden if isinstance(column.default, schema.PassiveDefault): diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index d07dd57341..0f6b659093 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -265,7 +265,7 @@ class SQLEngine(schema.SchemaEngine): """ raise NotImplementedError() - def defaultrunner(self, proxy): + def defaultrunner(self, proxy=None): """Returns a schema.SchemaVisitor instance that can execute the default values on a column. The base class for this visitor is the DefaultRunner class inside this module. This visitor will typically only receive schema.DefaultGenerator schema objects. The given @@ -275,7 +275,7 @@ class SQLEngine(schema.SchemaEngine): defaultrunner is called within the context of the execute_compiled() method.""" return DefaultRunner(self, proxy) - + def compiler(self, statement, parameters): """returns a sql.ClauseVisitor which will produce a string representation of the given ClauseElement and parameter dictionary. This object is usually a subclass of @@ -529,7 +529,7 @@ class SQLEngine(schema.SchemaEngine): self.post_exec(proxy, compiled, parameters, **kwargs) return ResultProxy(cursor, self, typemap=compiled.typemap) - def execute(self, statement, parameters, connection=None, cursor=None, echo=None, typemap=None, commit=False, return_raw=False, **kwargs): + def execute(self, statement, parameters=None, connection=None, cursor=None, echo=None, typemap=None, commit=False, return_raw=False, **kwargs): """executes the given string-based SQL statement with the given parameters. The parameters can be a dictionary or a list, or a list of dictionaries or lists, depending diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index a11a1539e8..17e421f228 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -434,11 +434,12 @@ class ForeignKey(SchemaItem): self.parent.table.foreign_keys.append(self) class DefaultGenerator(SchemaItem): - """Base class for column "default" values, which can be a plain default - or a Sequence.""" + """Base class for column "default" values.""" def _set_parent(self, column): self.column = column self.column.default = self + def execute(self): + return self.accept_schema_visitor(self.engine.defaultrunner(self.engine.execute)) def __repr__(self): return "DefaultGenerator()" @@ -464,17 +465,27 @@ class ColumnDefault(DefaultGenerator): class Sequence(DefaultGenerator): """represents a sequence, which applies to Oracle and Postgres databases.""" - def __init__(self, name, start = None, increment = None, optional=False): + def __init__(self, name, start = None, increment = None, optional=False, engine=None): self.name = name self.start = start self.increment = increment self.optional=optional + self.engine = engine def __repr__(self): return "Sequence(%s)" % string.join( [repr(self.name)] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['start', 'increment', 'optional']] , ',') - + def _set_parent(self, column): + super(Sequence, self)._set_parent(column) + column.sequence = self + if self.engine is None: + self.engine = column.table.engine + def create(self): + self.engine.create(self) + return self + def drop(self): + self.engine.drop(self) def accept_schema_visitor(self, visitor): """calls the visit_seauence method on the given visitor.""" return visitor.visit_sequence(self) diff --git a/test/query.py b/test/query.py index cf0bc94d32..a6e1bb4191 100644 --- a/test/query.py +++ b/test/query.py @@ -5,7 +5,7 @@ import unittest, sys, datetime import sqlalchemy.databases.sqlite as sqllite db = testbase.db -db.echo='debug' +#db.echo='debug' from sqlalchemy import * from sqlalchemy.engine import ResultProxy, RowProxy @@ -90,62 +90,6 @@ class QueryTest(PersistTest): finally: test_table.drop() - def testdefaults(self): - x = {'x':50} - def mydefault(): - x['x'] += 1 - return x['x'] - - use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle' - is_oracle = db.engine.name == 'oracle' - - # select "count(1)" from the DB which returns different results - # on different DBs - if is_oracle: - f = select([func.count(1) + 5], engine=db, from_obj=['DUAL']).scalar() - ts = select([func.sysdate()], engine=db, from_obj=['DUAL']).scalar() - def1 = func.sysdate() - def2 = text("sysdate") - deftype = Date - elif use_function_defaults: - f = select([func.count(1) + 5], engine=db).scalar() - def1 = func.current_date() - def2 = text("current_date") - deftype = Date - ts = select([func.current_date()], engine=db).scalar() - else: - f = select([func.count(1) + 5], engine=db).scalar() - def1 = def2 = "3" - ts = 3 - deftype = Integer - - t = Table('default_test1', db, - # python function - Column('col1', Integer, primary_key=True, default=mydefault), - - # python literal - Column('col2', String(20), default="imthedefault"), - - # preexecute expression - Column('col3', Integer, default=func.count(1) + 5), - - # SQL-side default from sql expression - Column('col4', deftype, PassiveDefault(def1)), - - # SQL-side default from literal expression - Column('col5', deftype, PassiveDefault(def2)) - ) - t.create() - try: - t.insert().execute() - self.assert_(t.engine.lastrow_has_defaults()) - t.insert().execute() - t.insert().execute() - - l = t.select().execute() - self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)]) - finally: - t.drop() def testdelete(self): c = db.connection() diff --git a/test/sequence.py b/test/sequence.py index 4d4390d18b..fcf852a86a 100644 --- a/test/sequence.py +++ b/test/sequence.py @@ -6,30 +6,106 @@ import testbase from sqlalchemy import * import sqlalchemy +db = testbase.db -class SequenceTest(PersistTest): +class DefaultTest(PersistTest): + + def testdefaults(self): + x = {'x':50} + def mydefault(): + x['x'] += 1 + return x['x'] + + use_function_defaults = db.engine.name == 'postgres' or db.engine.name == 'oracle' + is_oracle = db.engine.name == 'oracle' + + # select "count(1)" from the DB which returns different results + # on different DBs + if is_oracle: + f = select([func.count(1) + 5], engine=db, from_obj=['DUAL']).scalar() + ts = select([func.sysdate()], engine=db, from_obj=['DUAL']).scalar() + def1 = func.sysdate() + def2 = text("sysdate") + deftype = Date + elif use_function_defaults: + f = select([func.count(1) + 5], engine=db).scalar() + def1 = func.current_date() + def2 = text("current_date") + deftype = Date + ts = select([func.current_date()], engine=db).scalar() + else: + f = select([func.count(1) + 5], engine=db).scalar() + def1 = def2 = "3" + ts = 3 + deftype = Integer + + t = Table('default_test1', db, + # python function + Column('col1', Integer, primary_key=True, default=mydefault), + + # python literal + Column('col2', String(20), default="imthedefault"), + + # preexecute expression + Column('col3', Integer, default=func.count(1) + 5), + + # SQL-side default from sql expression + Column('col4', deftype, PassiveDefault(def1)), + + # SQL-side default from literal expression + Column('col5', deftype, PassiveDefault(def2)) + ) + t.create() + try: + t.insert().execute() + self.assert_(t.engine.lastrow_has_defaults()) + t.insert().execute() + t.insert().execute() + + l = t.select().execute() + self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts), (52, 'imthedefault', f, ts, ts), (53, 'imthedefault', f, ts, ts)]) + finally: + t.drop() - def setUp(self): - db = sqlalchemy.engine.create_engine('postgres', {'database':'test', 'host':'127.0.0.1', 'user':'scott', 'password':'tiger'}, echo=testbase.echo) - #db = sqlalchemy.engine.create_engine('oracle', {'dsn':os.environ['DSN'], 'user':os.environ['USER'], 'password':os.environ['PASSWORD']}, echo=testbase.echo) +class SequenceTest(PersistTest): - self.table = Table("cartitems", db, + def setUpAll(self): + if testbase.db.engine.name != 'postgres' and testbase.db.engine.name != 'oracle': + return + global cartitems + cartitems = Table("cartitems", db, Column("cart_id", Integer, Sequence('cart_id_seq'), primary_key=True), Column("description", String(40)), Column("createdate", DateTime()) ) - self.table.create() + cartitems.create() def testsequence(self): - self.table.insert().execute(description='hi') - self.table.insert().execute(description='there') - self.table.insert().execute(description='lala') + cartitems.insert().execute(description='hi') + cartitems.insert().execute(description='there') + cartitems.insert().execute(description='lala') - self.table.select().execute().fetchall() + cartitems.select().execute().fetchall() - def tearDown(self): - self.table.drop() + + def teststandalone(self): + s = Sequence("my_sequence", engine=db) + s.create() + try: + x =s.execute() + self.assert_(x == 1) + finally: + s.drop() + + def teststandalone2(self): + x = cartitems.c.cart_id.sequence.execute() + self.assert_(1 <= x <= 4) + + def tearDownAll(self): + if testbase.db.engine.name != 'postgres' and testbase.db.engine.name != 'oracle': + return + cartitems.drop() if __name__ == "__main__": unittest.main()