From: Mike Bayer Date: Sat, 11 Feb 2006 20:50:41 +0000 (+0000) Subject: streamlined engine.schemagenerator and engine.schemadropper methodology X-Git-Tag: rel_0_1_0~29 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=280274812261868e8f665f706cd27e06eaff4302;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git streamlined engine.schemagenerator and engine.schemadropper methodology added support for creating PassiveDefault (i.e. regular DEFAULT) on table columns postgres can reflect default values via information_schema added unittests for PassiveDefault values getting created, inserted, coming back in result sets --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 9688cb67bf..3b4ae64a70 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -20,11 +20,11 @@ def engine(**params): class ANSISQLEngine(sqlalchemy.engine.SQLEngine): - def schemagenerator(self, proxy, **params): - return ANSISchemaGenerator(proxy, **params) + def schemagenerator(self, **params): + return ANSISchemaGenerator(self, **params) - def schemadropper(self, proxy, **params): - return ANSISchemaDropper(proxy, **params) + def schemadropper(self, **params): + return ANSISchemaDropper(self, **params) def compiler(self, statement, parameters, **kwargs): return ANSICompiler(self, statement, parameters, **kwargs) @@ -492,7 +492,6 @@ class ANSICompiler(sql.Compiled): class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): - def get_column_specification(self, column, override_pk=False, first_pk=False): raise NotImplementedError() @@ -521,6 +520,16 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): def post_create_table(self, table): return '' + def get_column_default_string(self, column): + if isinstance(column.default, schema.PassiveDefault): + if not isinstance(column.default.arg, str): + arg = str(column.default.arg.compile(self.engine)) + else: + arg = column.default.arg + return arg + else: + return None + def visit_column(self, column): pass diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index c0503c25ce..f6dd251cd6 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -31,6 +31,7 @@ gen_columns = schema.Table("columns", generic_engine, Column("character_maximum_length", Integer), Column("numeric_precision", Integer), Column("numeric_scale", Integer), + Column("column_default", Integer), schema="information_schema") gen_constraints = schema.Table("table_constraints", generic_engine, @@ -109,15 +110,16 @@ def reflecttable(engine, table, ischema_names, use_mysql=False): row = c.fetchone() if row is None: break -# print "row! " + repr(row) + #print "row! " + repr(row) # continue - (name, type, nullable, charlen, numericprec, numericscale) = ( + (name, type, nullable, charlen, numericprec, numericscale, default) = ( row[columns.c.column_name], row[columns.c.data_type], row[columns.c.is_nullable] == 'YES', row[columns.c.character_maximum_length], row[columns.c.numeric_precision], row[columns.c.numeric_scale], + row[columns.c.column_default] ) args = [] @@ -127,7 +129,10 @@ def reflecttable(engine, table, ischema_names, use_mysql=False): coltype = ischema_names[type] #print "coltype " + repr(coltype) + " args " + repr(args) coltype = coltype(*args) - table.append_item(schema.Column(name, coltype, nullable = nullable)) + colargs= [] + if default is not None: + colargs.append(PassiveDefault(default)) + table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs)) s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True) if not use_mysql: diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 6734274cdb..0afac7df39 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -132,8 +132,8 @@ class MySQLEngine(ansisql.ANSISQLEngine): def compiler(self, statement, bindparams, **kwargs): return MySQLCompiler(self, statement, bindparams, **kwargs) - def schemagenerator(self, proxy, **params): - return MySQLSchemaGenerator(proxy, **params) + def schemagenerator(self, **params): + return MySQLSchemaGenerator(self, **params) def get_default_schema_name(self): if not hasattr(self, '_default_schema_name'): @@ -234,6 +234,13 @@ class MySQLTableImpl(sql.TableImpl): self.mysql_engine = mysql_engine class MySQLCompiler(ansisql.ANSICompiler): + + def visit_function(self, func): + if len(func.clauses): + super(MySQLCompiler, self).visit_function(func) + else: + self.strings[func] = func.name + def limit_clause(self, select): text = "" if select.limit is not None: @@ -248,6 +255,9 @@ class MySQLCompiler(ansisql.ANSICompiler): class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): colspec = column.name + " " + column.type.get_col_spec() + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default if not column.nullable: colspec += " NOT NULL" diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 857b0c2fce..2ce07a3c6c 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -104,10 +104,10 @@ class OracleSQLEngine(ansisql.ANSISQLEngine): def compiler(self, statement, bindparams, **kwargs): return OracleCompiler(self, statement, bindparams, use_ansi=self._use_ansi, **kwargs) - def schemagenerator(self, proxy, **params): - return OracleSchemaGenerator(proxy, **params) - def schemadropper(self, proxy, **params): - return OracleSchemaDropper(proxy, **params) + def schemagenerator(self, **params): + return OracleSchemaGenerator(self, **params) + def schemadropper(self, **params): + return OracleSchemaDropper(self, **params) def defaultrunner(self, proxy): return OracleDefaultRunner(self, proxy) @@ -227,6 +227,9 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name colspec += " " + column.type.get_col_spec() + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default if not column.nullable: colspec += " NOT NULL" diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 9122c2afa1..5d0a4e1729 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -192,11 +192,11 @@ class PGSQLEngine(ansisql.ANSISQLEngine): def compiler(self, statement, bindparams, **kwargs): return PGCompiler(self, statement, bindparams, **kwargs) - def schemagenerator(self, proxy, **params): - return PGSchemaGenerator(proxy, **params) + def schemagenerator(self, **params): + return PGSchemaGenerator(self, **params) - def schemadropper(self, proxy, **params): - return PGSchemaDropper(proxy, **params) + def schemadropper(self, **params): + return PGSchemaDropper(self, **params) def defaultrunner(self, proxy): return PGDefaultRunner(self, proxy) @@ -254,6 +254,12 @@ class PGSQLEngine(ansisql.ANSISQLEngine): class PGCompiler(ansisql.ANSICompiler): + def visit_function(self, func): + if len(func.clauses): + super(PGCompiler, self).visit_function(func) + else: + self.strings[func] = func.name + def visit_insert_column(self, column): # Postgres advises against OID usage and turns it off in 8.1, # effectively making cursor.lastrowid @@ -273,14 +279,16 @@ class PGCompiler(ansisql.ANSICompiler): return text class PGSchemaGenerator(ansisql.ANSISchemaGenerator): + def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name - if isinstance(column.default, schema.PassiveDefault): - colspec += " DEFAULT " + column.default.text - elif column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): + if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" else: colspec += " " + column.type.get_col_spec() + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default if not column.nullable: colspec += " NOT NULL" diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 83fb00205f..5401c350f3 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -148,8 +148,8 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): def dbapi(self): return sqlite - def schemagenerator(self, proxy, **params): - return SQLiteSchemaGenerator(proxy, **params) + def schemagenerator(self, **params): + return SQLiteSchemaGenerator(self, **params) def reflecttable(self, table): c = self.execute("PRAGMA table_info(" + table.name + ")", {}) @@ -226,6 +226,10 @@ class SQLiteCompiler(ansisql.ANSICompiler): class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name + " " + column.type.get_col_spec() + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + if not column.nullable: colspec += " NOT NULL" if column.primary_key and not override_pk: diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 29acdc665c..aa8e89ca4d 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -103,13 +103,13 @@ def engine_descriptors(): class SchemaIterator(schema.SchemaVisitor): """a visitor that can gather text into a buffer and execute the contents of the buffer.""" - def __init__(self, sqlproxy, **params): + def __init__(self, engine, **params): """initializes this SchemaIterator and initializes its buffer. sqlproxy - a callable function returned by SQLEngine.proxy(), which executes a statement plus optional parameters. """ - self.sqlproxy = sqlproxy + self.engine = engine self.buffer = StringIO.StringIO() def append(self, s): @@ -120,7 +120,7 @@ class SchemaIterator(schema.SchemaVisitor): """executes the contents of the SchemaIterator's buffer using its sql proxy and clears out the buffer.""" try: - return self.sqlproxy(self.buffer.getvalue()) + return self.engine.execute(self.buffer.getvalue(), None) finally: self.buffer.truncate(0) @@ -250,21 +250,17 @@ class SQLEngine(schema.SchemaEngine): """returns a sql.text() object for performing literal queries.""" return sql.text(text, engine=self, *args, **kwargs) - def schemagenerator(self, proxy, **params): + def schemagenerator(self, **params): """returns a schema.SchemaVisitor instance that can generate schemas, when it is - invoked to traverse a set of schema objects. The - "proxy" argument is a callable will execute a given string SQL statement - and a dictionary or list of parameters. + invoked to traverse a set of schema objects. schemagenerator is called via the create() method. """ raise NotImplementedError() - def schemadropper(self, proxy, **params): + def schemadropper(self, **params): """returns a schema.SchemaVisitor instance that can drop schemas, when it is - invoked to traverse a set of schema objects. The - "proxy" argument is a callable will execute a given string SQL statement - and a dictionary or list of parameters. + invoked to traverse a set of schema objects. schemagenerator is called via the drop() method. """ @@ -300,11 +296,11 @@ class SQLEngine(schema.SchemaEngine): def create(self, table, **params): """creates a table within this engine's database connection given a schema.Table object.""" - table.accept_visitor(self.schemagenerator(self.proxy(), **params)) + table.accept_visitor(self.schemagenerator(**params)) def drop(self, table, **params): """drops a table within this engine's database connection given a schema.Table object.""" - table.accept_visitor(self.schemadropper(self.proxy(), **params)) + table.accept_visitor(self.schemadropper(**params)) def compile(self, statement, parameters, **kwargs): """given a sql.ClauseElement statement plus optional bind parameters, creates a new @@ -369,12 +365,6 @@ class SQLEngine(schema.SchemaEngine): """implementations might want to put logic here for turning autocommit on/off, etc.""" connection.commit() - def proxy(self, **kwargs): - """provides a callable that will execute the given string statement and parameters. - The statement and parameters should be in the format specific to the particular database; - i.e. named or positional.""" - return lambda s, p = None: self.execute(s, p, **kwargs) - def connection(self): """returns a managed DBAPI connection from this SQLEngine's connection pool.""" return self._pool.connect() diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 01b7c7a113..8e85fb310b 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -19,7 +19,7 @@ from sqlalchemy.util import * from sqlalchemy.types import * import copy, re, string -__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor'] +__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] class SchemaItem(object): @@ -418,12 +418,12 @@ class DefaultGenerator(SchemaItem): class PassiveDefault(DefaultGenerator): """a default that takes effect on the database side""" - def __init__(self, text): - self.text = text + def __init__(self, arg): + self.arg = arg def accept_visitor(self, visitor): - return visitor_visit_passive_default(self) + return visitor.visit_passive_default(self) def __repr__(self): - return "PassiveDefault(%s)" % repr(self.text) + return "PassiveDefault(%s)" % repr(self.arg) class ColumnDefault(DefaultGenerator): """A plain default value on a column. this could correspond to a constant, diff --git a/test/engines.py b/test/engines.py index 75ac894a35..f7bde7118a 100644 --- a/test/engines.py +++ b/test/engines.py @@ -13,6 +13,16 @@ import unittest, re class EngineTest(PersistTest): def testbasic(self): # really trip it up with a circular reference + + use_function_defaults = testbase.db.engine.__module__.endswith('postgres') or testbase.db.engine.__module__.endswith('oracle') + + if use_function_defaults: + defval = func.current_date() + deftype = Date + else: + defval = "3" + deftype = Integer + users = Table('engine_users', testbase.db, Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20), nullable = False), @@ -25,6 +35,7 @@ class EngineTest(PersistTest): Column('test6', DateTime, nullable = False), Column('test7', String), Column('test8', Binary), + Column('test_passivedefault', deftype, PassiveDefault(defval)), Column('test9', Binary(100)), mysql_engine='InnoDB' ) diff --git a/test/query.py b/test/query.py index 9c2bcfe441..6c4e017cd0 100644 --- a/test/query.py +++ b/test/query.py @@ -5,7 +5,7 @@ import unittest, sys, datetime import sqlalchemy.databases.sqlite as sqllite db = testbase.db - +db.echo='debug' from sqlalchemy import * from sqlalchemy.engine import ResultProxy, RowProxy @@ -46,15 +46,28 @@ class QueryTest(PersistTest): def mydefault(): x['x'] += 1 return x['x'] - + + use_function_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle') + # select "count(1)" from the DB which returns different results # on different DBs - f = select([func.count(1)], engine=db).execute().fetchone()[0] - + f = select([func.count(1)], engine=db).scalar() + if use_function_defaults: + def1 = func.current_date() + def2 = "current_date" + deftype = Date + ts = select([func.current_date()], engine=db).scalar() + else: + def1 = def2 = "3" + ts = 3 + deftype = Integer + t = Table('default_test1', db, Column('col1', Integer, primary_key=True, default=mydefault), Column('col2', String(20), default="imthedefault"), Column('col3', Integer, default=func.count(1)), + Column('col4', deftype, PassiveDefault(def1)), + Column('col5', deftype, PassiveDefault(def2)) ) t.create() try: @@ -63,7 +76,7 @@ class QueryTest(PersistTest): t.insert().execute() l = t.select().execute() - self.assert_(l.fetchall() == [(1, 'imthedefault', f), (2, 'imthedefault', f), (3, 'imthedefault', f)]) + self.assert_(l.fetchall() == [(1, 'imthedefault', f, ts, ts), (2, 'imthedefault', f, ts, ts), (3, 'imthedefault', f, ts, ts)]) finally: t.drop()