]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
streamlined engine.schemagenerator and engine.schemadropper methodology
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Feb 2006 20:50:41 +0000 (20:50 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 11 Feb 2006 20:50:41 +0000 (20:50 +0000)
added support for creating PassiveDefault (i.e. regular DEFAULT) on table columns
postgres can reflect default values via information_schema
added unittests for PassiveDefault values getting created, inserted, coming back in result sets

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/information_schema.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/schema.py
test/engines.py
test/query.py

index 9688cb67bf1e1627f41175bb544440c70294472a..3b4ae64a7015d342b6da53d7bf30f9092fe791b7 100644 (file)
@@ -20,11 +20,11 @@ def engine(**params):
 
 class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
 
-    def schemagenerator(self, proxy, **params):
-        return ANSISchemaGenerator(proxy, **params)
+    def schemagenerator(self, **params):
+        return ANSISchemaGenerator(self, **params)
     
-    def schemadropper(self, proxy, **params):
-        return ANSISchemaDropper(proxy, **params)
+    def schemadropper(self, **params):
+        return ANSISchemaDropper(self, **params)
 
     def compiler(self, statement, parameters, **kwargs):
         return ANSICompiler(self, statement, parameters, **kwargs)
@@ -492,7 +492,6 @@ class ANSICompiler(sql.Compiled):
 
 
 class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
-
     def get_column_specification(self, column, override_pk=False, first_pk=False):
         raise NotImplementedError()
         
@@ -521,6 +520,16 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
     def post_create_table(self, table):
         return ''
 
+    def get_column_default_string(self, column):
+        if isinstance(column.default, schema.PassiveDefault):
+            if not isinstance(column.default.arg, str):
+                arg = str(column.default.arg.compile(self.engine))
+            else:
+                arg = column.default.arg
+            return arg
+        else:
+            return None
+
     def visit_column(self, column):
         pass
     
index c0503c25cecc5694e0fbffcf9fbb1752cb3b7563..f6dd251cd67662845030c4d541312c8643c79367 100644 (file)
@@ -31,6 +31,7 @@ gen_columns = schema.Table("columns", generic_engine,
     Column("character_maximum_length", Integer),
     Column("numeric_precision", Integer),
     Column("numeric_scale", Integer),
+    Column("column_default", Integer),
     schema="information_schema")
     
 gen_constraints = schema.Table("table_constraints", generic_engine,
@@ -109,15 +110,16 @@ def reflecttable(engine, table, ischema_names, use_mysql=False):
         row = c.fetchone()
         if row is None:
             break
-#        print "row! " + repr(row)
+        #print "row! " + repr(row)
  #       continue
-        (name, type, nullable, charlen, numericprec, numericscale) = (
+        (name, type, nullable, charlen, numericprec, numericscale, default) = (
             row[columns.c.column_name], 
             row[columns.c.data_type], 
             row[columns.c.is_nullable] == 'YES', 
             row[columns.c.character_maximum_length],
             row[columns.c.numeric_precision],
             row[columns.c.numeric_scale],
+            row[columns.c.column_default]
             )
 
         args = []
@@ -127,7 +129,10 @@ def reflecttable(engine, table, ischema_names, use_mysql=False):
         coltype = ischema_names[type]
         #print "coltype " + repr(coltype) + " args " +  repr(args)
         coltype = coltype(*args)
-        table.append_item(schema.Column(name, coltype, nullable = nullable))
+        colargs= []
+        if default is not None:
+            colargs.append(PassiveDefault(default))
+        table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs))
 
     s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True)
     if not use_mysql:
index 6734274cdbdc4ee0205b5dc74da318b628788553..0afac7df390c6bcb58a53b210a0f6aeec76ae950 100644 (file)
@@ -132,8 +132,8 @@ class MySQLEngine(ansisql.ANSISQLEngine):
     def compiler(self, statement, bindparams, **kwargs):
         return MySQLCompiler(self, statement, bindparams, **kwargs)
 
-    def schemagenerator(self, proxy, **params):
-        return MySQLSchemaGenerator(proxy, **params)
+    def schemagenerator(self, **params):
+        return MySQLSchemaGenerator(self, **params)
 
     def get_default_schema_name(self):
         if not hasattr(self, '_default_schema_name'):
@@ -234,6 +234,13 @@ class MySQLTableImpl(sql.TableImpl):
         self.mysql_engine = mysql_engine
 
 class MySQLCompiler(ansisql.ANSICompiler):
+
+    def visit_function(self, func):
+        if len(func.clauses):
+            super(MySQLCompiler, self).visit_function(func)
+        else:
+            self.strings[func] = func.name
+
     def limit_clause(self, select):
         text = ""
         if select.limit is not None:
@@ -248,6 +255,9 @@ class MySQLCompiler(ansisql.ANSICompiler):
 class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, first_pk=False):
         colspec = column.name + " " + column.type.get_col_spec()
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
 
         if not column.nullable:
             colspec += " NOT NULL"
index 857b0c2fce427f54746809cdd396574cc522f6d9..2ce07a3c6c8efaa45d04f7ff2063be9c28bc2bc3 100644 (file)
@@ -104,10 +104,10 @@ class OracleSQLEngine(ansisql.ANSISQLEngine):
     def compiler(self, statement, bindparams, **kwargs):
         return OracleCompiler(self, statement, bindparams, use_ansi=self._use_ansi, **kwargs)
 
-    def schemagenerator(self, proxy, **params):
-        return OracleSchemaGenerator(proxy, **params)
-    def schemadropper(self, proxy, **params):
-        return OracleSchemaDropper(proxy, **params)
+    def schemagenerator(self, **params):
+        return OracleSchemaGenerator(self, **params)
+    def schemadropper(self, **params):
+        return OracleSchemaDropper(self, **params)
     def defaultrunner(self, proxy):
         return OracleDefaultRunner(self, proxy)
         
@@ -227,6 +227,9 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name
         colspec += " " + column.type.get_col_spec()
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
 
         if not column.nullable:
             colspec += " NOT NULL"
index 9122c2afa10827603072098dc5d5ca535d896fcd..5d0a4e1729b076d16fa8f7e8ff5f967cddb5bf09 100644 (file)
@@ -192,11 +192,11 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
     def compiler(self, statement, bindparams, **kwargs):
         return PGCompiler(self, statement, bindparams, **kwargs)
 
-    def schemagenerator(self, proxy, **params):
-        return PGSchemaGenerator(proxy, **params)
+    def schemagenerator(self, **params):
+        return PGSchemaGenerator(self, **params)
 
-    def schemadropper(self, proxy, **params):
-        return PGSchemaDropper(proxy, **params)
+    def schemadropper(self, **params):
+        return PGSchemaDropper(self, **params)
 
     def defaultrunner(self, proxy):
         return PGDefaultRunner(self, proxy)
@@ -254,6 +254,12 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
 
 class PGCompiler(ansisql.ANSICompiler):
 
+    def visit_function(self, func):
+        if len(func.clauses):
+            super(PGCompiler, self).visit_function(func)
+        else:
+            self.strings[func] = func.name
+
     def visit_insert_column(self, column):
         # Postgres advises against OID usage and turns it off in 8.1,
         # effectively making cursor.lastrowid
@@ -273,14 +279,16 @@ class PGCompiler(ansisql.ANSICompiler):
         return text
         
 class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
+        
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name
-        if isinstance(column.default, schema.PassiveDefault):
-            colspec += " DEFAULT " + column.default.text
-        elif column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
+        if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
             colspec += " SERIAL"
         else:
             colspec += " " + column.type.get_col_spec()
+            default = self.get_column_default_string(column)
+            if default is not None:
+                colspec += " DEFAULT " + default
 
         if not column.nullable:
             colspec += " NOT NULL"
index 83fb00205f7cc3da61ce8bee88876bae66784015..5401c350f3b991909511fb3ebfb6362a58524d5e 100644 (file)
@@ -148,8 +148,8 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
     def dbapi(self):
         return sqlite
 
-    def schemagenerator(self, proxy, **params):
-        return SQLiteSchemaGenerator(proxy, **params)
+    def schemagenerator(self, **params):
+        return SQLiteSchemaGenerator(self, **params)
 
     def reflecttable(self, table):
         c = self.execute("PRAGMA table_info(" + table.name + ")", {})
@@ -226,6 +226,10 @@ class SQLiteCompiler(ansisql.ANSICompiler):
 class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name + " " + column.type.get_col_spec()
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
+
         if not column.nullable:
             colspec += " NOT NULL"
         if column.primary_key and not override_pk:
index 29acdc665cc7f4d7fd4992ec86c899eaefd970d9..aa8e89ca4de128d6479fb25a13d44988104998c2 100644 (file)
@@ -103,13 +103,13 @@ def engine_descriptors():
     
 class SchemaIterator(schema.SchemaVisitor):
     """a visitor that can gather text into a buffer and execute the contents of the buffer."""
-    def __init__(self, sqlproxy, **params):
+    def __init__(self, engine, **params):
         """initializes this SchemaIterator and initializes its buffer.
         
         sqlproxy - a callable function returned by SQLEngine.proxy(), which executes a
         statement plus optional parameters.
         """
-        self.sqlproxy = sqlproxy
+        self.engine = engine
         self.buffer = StringIO.StringIO()
 
     def append(self, s):
@@ -120,7 +120,7 @@ class SchemaIterator(schema.SchemaVisitor):
         """executes the contents of the SchemaIterator's buffer using its sql proxy and
         clears out the buffer."""
         try:
-            return self.sqlproxy(self.buffer.getvalue())
+            return self.engine.execute(self.buffer.getvalue(), None)
         finally:
             self.buffer.truncate(0)
 
@@ -250,21 +250,17 @@ class SQLEngine(schema.SchemaEngine):
         """returns a sql.text() object for performing literal queries."""
         return sql.text(text, engine=self, *args, **kwargs)
         
-    def schemagenerator(self, proxy, **params):
+    def schemagenerator(self, **params):
         """returns a schema.SchemaVisitor instance that can generate schemas, when it is
-        invoked to traverse a set of schema objects.  The 
-        "proxy" argument is a callable will execute a given string SQL statement
-        and a dictionary or list of parameters.  
+        invoked to traverse a set of schema objects. 
         
         schemagenerator is called via the create() method.
         """
         raise NotImplementedError()
 
-    def schemadropper(self, proxy, **params):
+    def schemadropper(self, **params):
         """returns a schema.SchemaVisitor instance that can drop schemas, when it is
-        invoked to traverse a set of schema objects.  The 
-        "proxy" argument is a callable will execute a given string SQL statement
-        and a dictionary or list of parameters.  
+        invoked to traverse a set of schema objects. 
         
         schemagenerator is called via the drop() method.
         """
@@ -300,11 +296,11 @@ class SQLEngine(schema.SchemaEngine):
         
     def create(self, table, **params):
         """creates a table within this engine's database connection given a schema.Table object."""
-        table.accept_visitor(self.schemagenerator(self.proxy(), **params))
+        table.accept_visitor(self.schemagenerator(**params))
 
     def drop(self, table, **params):
         """drops a table within this engine's database connection given a schema.Table object."""
-        table.accept_visitor(self.schemadropper(self.proxy(), **params))
+        table.accept_visitor(self.schemadropper(**params))
 
     def compile(self, statement, parameters, **kwargs):
         """given a sql.ClauseElement statement plus optional bind parameters, creates a new
@@ -369,12 +365,6 @@ class SQLEngine(schema.SchemaEngine):
         """implementations might want to put logic here for turning autocommit on/off, etc."""
         connection.commit()
 
-    def proxy(self, **kwargs):
-        """provides a callable that will execute the given string statement and parameters.
-        The statement and parameters should be in the format specific to the particular database;
-        i.e. named or positional."""
-        return lambda s, p = None: self.execute(s, p, **kwargs)
-
     def connection(self):
         """returns a managed DBAPI connection from this SQLEngine's connection pool."""
         return self._pool.connect()
index 01b7c7a1138a987ed4afd7c4f3dd669d89d89f66..8e85fb310b1c34d5bcb615ff486bbd0d38bcf327 100644 (file)
@@ -19,7 +19,7 @@ from sqlalchemy.util import *
 from sqlalchemy.types import *
 import copy, re, string
 
-__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor']
+__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
 
 
 class SchemaItem(object):
@@ -418,12 +418,12 @@ class DefaultGenerator(SchemaItem):
 
 class PassiveDefault(DefaultGenerator):
     """a default that takes effect on the database side"""
-    def __init__(self, text):
-        self.text = text
+    def __init__(self, arg):
+        self.arg = arg
     def accept_visitor(self, visitor):
-        return visitor_visit_passive_default(self)
+        return visitor.visit_passive_default(self)
     def __repr__(self):
-        return "PassiveDefault(%s)" % repr(self.text)
+        return "PassiveDefault(%s)" % repr(self.arg)
         
 class ColumnDefault(DefaultGenerator):
     """A plain default value on a column.  this could correspond to a constant, 
index 75ac894a35a111907cf04d78b4d9c99135226fd8..f7bde7118a1582b39daed9e659c745253ed44972 100644 (file)
@@ -13,6 +13,16 @@ import unittest, re
 class EngineTest(PersistTest):
     def testbasic(self):
         # really trip it up with a circular reference
+        
+        use_function_defaults = testbase.db.engine.__module__.endswith('postgres') or testbase.db.engine.__module__.endswith('oracle')
+        
+        if use_function_defaults:
+            defval = func.current_date()
+            deftype = Date
+        else:
+            defval = "3"
+            deftype = Integer
+            
         users = Table('engine_users', testbase.db,
             Column('user_id', INT, primary_key = True),
             Column('user_name', VARCHAR(20), nullable = False),
@@ -25,6 +35,7 @@ class EngineTest(PersistTest):
             Column('test6', DateTime, nullable = False),
             Column('test7', String),
             Column('test8', Binary),
+            Column('test_passivedefault', deftype, PassiveDefault(defval)),
             Column('test9', Binary(100)),
             mysql_engine='InnoDB'
         )
index 9c2bcfe441f82c6a91408960b78b72e6d43c56c8..6c4e017cd0f3bc6ec776b7a5165bdf0e331780fc 100644 (file)
@@ -5,7 +5,7 @@ import unittest, sys, datetime
 import sqlalchemy.databases.sqlite as sqllite
 
 db = testbase.db
-
+db.echo='debug'
 from sqlalchemy import *
 from sqlalchemy.engine import ResultProxy, RowProxy
 
@@ -46,15 +46,28 @@ class QueryTest(PersistTest):
         def mydefault():
             x['x'] += 1
             return x['x']
-            
+
+        use_function_defaults = db.engine.__module__.endswith('postgres') or db.engine.__module__.endswith('oracle')
+        
         # select "count(1)" from the DB which returns different results
         # on different DBs
-        f = select([func.count(1)], engine=db).execute().fetchone()[0]
-        
+        f = select([func.count(1)], engine=db).scalar()
+        if use_function_defaults:
+            def1 = func.current_date()
+            def2 = "current_date"
+            deftype = Date
+            ts = select([func.current_date()], engine=db).scalar()
+        else:
+            def1 = def2 = "3"
+            ts = 3
+            deftype = Integer
+            
         t = Table('default_test1', db, 
             Column('col1', Integer, primary_key=True, default=mydefault),
             Column('col2', String(20), default="imthedefault"),
             Column('col3', Integer, default=func.count(1)),
+            Column('col4', deftype, PassiveDefault(def1)),
+            Column('col5', deftype, PassiveDefault(def2))
         )
         t.create()
         try:
@@ -63,7 +76,7 @@ class QueryTest(PersistTest):
             t.insert().execute()
         
             l = t.select().execute()
-            self.assert_(l.fetchall() == [(1, 'imthedefault', f), (2, 'imthedefault', f), (3, 'imthedefault', f)])
+            self.assert_(l.fetchall() == [(1, 'imthedefault', f, ts, ts), (2, 'imthedefault', f, ts, ts), (3, 'imthedefault', f, ts, ts)])
         finally:
             t.drop()