added support for creating PassiveDefault (i.e. regular DEFAULT) on table columns
postgres can reflect default values via information_schema
added unittests for PassiveDefault values getting created, inserted, coming back in result sets
class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
- def schemagenerator(self, proxy, **params):
- return ANSISchemaGenerator(proxy, **params)
+ def schemagenerator(self, **params):
+ return ANSISchemaGenerator(self, **params)
- def schemadropper(self, proxy, **params):
- return ANSISchemaDropper(proxy, **params)
+ def schemadropper(self, **params):
+ return ANSISchemaDropper(self, **params)
def compiler(self, statement, parameters, **kwargs):
return ANSICompiler(self, statement, parameters, **kwargs)
class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
-
def get_column_specification(self, column, override_pk=False, first_pk=False):
raise NotImplementedError()
def post_create_table(self, table):
return ''
+ def get_column_default_string(self, column):
+ if isinstance(column.default, schema.PassiveDefault):
+ if not isinstance(column.default.arg, str):
+ arg = str(column.default.arg.compile(self.engine))
+ else:
+ arg = column.default.arg
+ return arg
+ else:
+ return None
+
def visit_column(self, column):
pass
Column("character_maximum_length", Integer),
Column("numeric_precision", Integer),
Column("numeric_scale", Integer),
+ Column("column_default", Integer),
schema="information_schema")
gen_constraints = schema.Table("table_constraints", generic_engine,
row = c.fetchone()
if row is None:
break
-# print "row! " + repr(row)
+ #print "row! " + repr(row)
# continue
- (name, type, nullable, charlen, numericprec, numericscale) = (
+ (name, type, nullable, charlen, numericprec, numericscale, default) = (
row[columns.c.column_name],
row[columns.c.data_type],
row[columns.c.is_nullable] == 'YES',
row[columns.c.character_maximum_length],
row[columns.c.numeric_precision],
row[columns.c.numeric_scale],
+ row[columns.c.column_default]
)
args = []
coltype = ischema_names[type]
#print "coltype " + repr(coltype) + " args " + repr(args)
coltype = coltype(*args)
- table.append_item(schema.Column(name, coltype, nullable = nullable))
+ colargs= []
+ if default is not None:
+ colargs.append(PassiveDefault(default))
+ table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs))
s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True)
if not use_mysql:
def compiler(self, statement, bindparams, **kwargs):
return MySQLCompiler(self, statement, bindparams, **kwargs)
- def schemagenerator(self, proxy, **params):
- return MySQLSchemaGenerator(proxy, **params)
+ def schemagenerator(self, **params):
+ return MySQLSchemaGenerator(self, **params)
def get_default_schema_name(self):
if not hasattr(self, '_default_schema_name'):
self.mysql_engine = mysql_engine
class MySQLCompiler(ansisql.ANSICompiler):
+
+ def visit_function(self, func):
+ if len(func.clauses):
+ super(MySQLCompiler, self).visit_function(func)
+ else:
+ self.strings[func] = func.name
+
def limit_clause(self, select):
text = ""
if select.limit is not None:
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
colspec = column.name + " " + column.type.get_col_spec()
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
if not column.nullable:
colspec += " NOT NULL"
def compiler(self, statement, bindparams, **kwargs):
return OracleCompiler(self, statement, bindparams, use_ansi=self._use_ansi, **kwargs)
- def schemagenerator(self, proxy, **params):
- return OracleSchemaGenerator(proxy, **params)
- def schemadropper(self, proxy, **params):
- return OracleSchemaDropper(proxy, **params)
+ def schemagenerator(self, **params):
+ return OracleSchemaGenerator(self, **params)
+ def schemadropper(self, **params):
+ return OracleSchemaDropper(self, **params)
def defaultrunner(self, proxy):
return OracleDefaultRunner(self, proxy)
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name
colspec += " " + column.type.get_col_spec()
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
if not column.nullable:
colspec += " NOT NULL"
def compiler(self, statement, bindparams, **kwargs):
return PGCompiler(self, statement, bindparams, **kwargs)
- def schemagenerator(self, proxy, **params):
- return PGSchemaGenerator(proxy, **params)
+ def schemagenerator(self, **params):
+ return PGSchemaGenerator(self, **params)
- def schemadropper(self, proxy, **params):
- return PGSchemaDropper(proxy, **params)
+ def schemadropper(self, **params):
+ return PGSchemaDropper(self, **params)
def defaultrunner(self, proxy):
return PGDefaultRunner(self, proxy)
class PGCompiler(ansisql.ANSICompiler):
+ def visit_function(self, func):
+ if len(func.clauses):
+ super(PGCompiler, self).visit_function(func)
+ else:
+ self.strings[func] = func.name
+
def visit_insert_column(self, column):
# Postgres advises against OID usage and turns it off in 8.1,
# effectively making cursor.lastrowid
return text
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
+
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name
- if isinstance(column.default, schema.PassiveDefault):
- colspec += " DEFAULT " + column.default.text
- elif column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+ if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
colspec += " SERIAL"
else:
colspec += " " + column.type.get_col_spec()
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
if not column.nullable:
colspec += " NOT NULL"
def dbapi(self):
return sqlite
- def schemagenerator(self, proxy, **params):
- return SQLiteSchemaGenerator(proxy, **params)
+ def schemagenerator(self, **params):
+ return SQLiteSchemaGenerator(self, **params)
def reflecttable(self, table):
c = self.execute("PRAGMA table_info(" + table.name + ")", {})
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name + " " + column.type.get_col_spec()
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
if not column.nullable:
colspec += " NOT NULL"
if column.primary_key and not override_pk:
class SchemaIterator(schema.SchemaVisitor):
"""a visitor that can gather text into a buffer and execute the contents of the buffer."""
- def __init__(self, sqlproxy, **params):
+ def __init__(self, engine, **params):
"""initializes this SchemaIterator and initializes its buffer.
sqlproxy - a callable function returned by SQLEngine.proxy(), which executes a
statement plus optional parameters.
"""
- self.sqlproxy = sqlproxy
+ self.engine = engine
self.buffer = StringIO.StringIO()
def append(self, s):
"""executes the contents of the SchemaIterator's buffer using its sql proxy and
clears out the buffer."""
try:
- return self.sqlproxy(self.buffer.getvalue())
+ return self.engine.execute(self.buffer.getvalue(), None)
finally:
self.buffer.truncate(0)
"""returns a sql.text() object for performing literal queries."""
return sql.text(text, engine=self, *args, **kwargs)
- def schemagenerator(self, proxy, **params):
+ def schemagenerator(self, **params):
"""returns a schema.SchemaVisitor instance that can generate schemas, when it is
- invoked to traverse a set of schema objects. The
- "proxy" argument is a callable will execute a given string SQL statement
- and a dictionary or list of parameters.
+ invoked to traverse a set of schema objects.
schemagenerator is called via the create() method.
"""
raise NotImplementedError()
- def schemadropper(self, proxy, **params):
+ def schemadropper(self, **params):
"""returns a schema.SchemaVisitor instance that can drop schemas, when it is
- invoked to traverse a set of schema objects. The
- "proxy" argument is a callable will execute a given string SQL statement
- and a dictionary or list of parameters.
+ invoked to traverse a set of schema objects.
schemagenerator is called via the drop() method.
"""
def create(self, table, **params):
"""creates a table within this engine's database connection given a schema.Table object."""
- table.accept_visitor(self.schemagenerator(self.proxy(), **params))
+ table.accept_visitor(self.schemagenerator(**params))
def drop(self, table, **params):
"""drops a table within this engine's database connection given a schema.Table object."""
- table.accept_visitor(self.schemadropper(self.proxy(), **params))
+ table.accept_visitor(self.schemadropper(**params))
def compile(self, statement, parameters, **kwargs):
"""given a sql.ClauseElement statement plus optional bind parameters, creates a new
"""implementations might want to put logic here for turning autocommit on/off, etc."""
connection.commit()
- def proxy(self, **kwargs):
- """provides a callable that will execute the given string statement and parameters.
- The statement and parameters should be in the format specific to the particular database;
- i.e. named or positional."""
- return lambda s, p = None: self.execute(s, p, **kwargs)
-
def connection(self):
"""returns a managed DBAPI connection from this SQLEngine's connection pool."""
return self._pool.connect()
from sqlalchemy.types import *
import copy, re, string
-__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor']
+__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
class SchemaItem(object):
class PassiveDefault(DefaultGenerator):
"""a default that takes effect on the database side"""
- def __init__(self, text):
- self.text = text
+ def __init__(self, arg):
+ self.arg = arg
def accept_visitor(self, visitor):
- return visitor_visit_passive_default(self)
+ return visitor.visit_passive_default(self)
def __repr__(self):
- return "PassiveDefault(%s)" % repr(self.text)
+ return "PassiveDefault(%s)" % repr(self.arg)
class ColumnDefault(DefaultGenerator):
"""A plain default value on a column. this could correspond to a constant,
class EngineTest(PersistTest):
def testbasic(self):
# really trip it up with a circular reference
+
+ use_function_defaults = testbase.db.engine.__module__.endswith('postgres') or testbase.db.engine.__module__.endswith('oracle')
+
+ if use_function_defaults:
+ defval = func.current_date()
+ deftype = Date
+ else:
+ defval = "3"
+ deftype = Integer
+
users = Table('engine_users', testbase.db,
Column('user_id', INT, primary_key = True),
Column('user_name', VARCHAR(20), nullable = False),
Column('test6', DateTime, nullable = False),
Column('test7', String),
Column('test8', Binary),
+ Column('test_passivedefault', deftype, PassiveDefault(defval)),
Column('test9', Binary(100)),
mysql_engine='InnoDB'
)
import sqlalchemy.databases.sqlite as sqllite
db = testbase.db
-
+db.echo='debug'
from sqlalchemy import *
from sqlalchemy.engine import ResultProxy, RowProxy
def mydefault():
x['x'] += 1
return x['x']
-
+
+ use_function_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle')
+
# select "count(1)" from the DB which returns different results
# on different DBs
- f = select([func.count(1)], engine=db).execute().fetchone()[0]
-
+ f = select([func.count(1)], engine=db).scalar()
+ if use_function_defaults:
+ def1 = func.current_date()
+ def2 = "current_date"
+ deftype = Date
+ ts = select([func.current_date()], engine=db).scalar()
+ else:
+ def1 = def2 = "3"
+ ts = 3
+ deftype = Integer
+
t = Table('default_test1', db,
Column('col1', Integer, primary_key=True, default=mydefault),
Column('col2', String(20), default="imthedefault"),
Column('col3', Integer, default=func.count(1)),
+ Column('col4', deftype, PassiveDefault(def1)),
+ Column('col5', deftype, PassiveDefault(def2))
)
t.create()
try:
t.insert().execute()
l = t.select().execute()
- self.assert_(l.fetchall() == [(1, 'imthedefault', f), (2, 'imthedefault', f), (3, 'imthedefault', f)])
+ self.assert_(l.fetchall() == [(1, 'imthedefault', f, ts, ts), (2, 'imthedefault', f, ts, ts), (3, 'imthedefault', f, ts, ts)])
finally:
t.drop()