]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added free-form `DDL` statements, can be executed standalone or tied to the DDL...
authorJason Kirtland <jek@discorporate.us>
Tue, 5 Feb 2008 05:46:33 +0000 (05:46 +0000)
committerJason Kirtland <jek@discorporate.us>
Tue, 5 Feb 2008 05:46:33 +0000 (05:46 +0000)
- Added DDL event hooks, triggers callables before and after create / drop.

CHANGES
lib/sqlalchemy/__init__.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/util.py
test/engine/alltests.py
test/engine/ddlevents.py [new file with mode: 0644]

diff --git a/CHANGES b/CHANGES
index b1c1070bd0be7a9cb5231102fef5777cf07d75d1..4b202edd38535597882fe7d9679a1893870cc7ae 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -5,14 +5,22 @@ CHANGES
 0.4.3
 ------
 - sql
+    - Added "schema.DDL", an executable free-form DDL statement.
+      DDLs can be executed in isolation or attached to Table or
+      MetaData instances and executed automatically when those
+      objects are created and/or dropped.
+
+    - Added a callable-based DDL events interface, adds hooks
+      before and after Tables and MetaData create and drop.
+
     - Added "ilike()" operator to column operations.  Compiles
       to ILIKE on postgres, lower(x) LIKE lower(y) on all
       others. [ticket:727]
 
-    - added "now()" as a generic function; on SQLite and 
+    - added "now()" as a generic function; on SQLite and
       Oracle compiles as "CURRENT_TIMESTAMP"; "now()"
       on all others [ticket:943]
-      
+
     - the startswith(), endswith(), and contains() operators
       now concatenate the wildcard operator with the given
       operand in SQL, i.e. "'%' || <bindparam>" in all cases,
@@ -22,11 +30,11 @@ CHANGES
       operands properly [ticket:962]
 
     - added "autocommit=True" kwarg to select() and text(),
-      as well as generative autocommit() method on select(); 
-      for statements which modify the database through some 
+      as well as generative autocommit() method on select();
+      for statements which modify the database through some
       user-defined means other than the usual INSERT/UPDATE/
       DELETE etc., this flag will enable "autocommit" behavior
-      during execution if no transaction is in progress 
+      during execution if no transaction is in progress
       [ticket:915]
 
     - The '.c.' attribute on a selectable now gets an entry
index 0fc4e117ef9ed8dc98ec49bd6c0625dc778baa30..9e2d50c269d06a0b48a98b989ed3a6bdc95729d6 100644 (file)
@@ -24,7 +24,7 @@ from sqlalchemy.schema import \
     MetaData, ThreadLocalMetaData, Table, Column, ForeignKey, \
     Sequence, Index, ForeignKeyConstraint, PrimaryKeyConstraint, \
     CheckConstraint, UniqueConstraint, Constraint, \
-    PassiveDefault, ColumnDefault
+    PassiveDefault, ColumnDefault, DDL
 
 from sqlalchemy.engine import create_engine, engine_from_config
 
index 2bbbf398d6a2f468f1aeae8bae8b9766399182b4..9a7280065e40ca2c6c0bee7d30e6447f6ec9aaa6 100644 (file)
@@ -917,6 +917,13 @@ class Connection(Connectable):
         else:
             self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context)
 
+    def _execute_ddl(self, ddl, params, multiparams):
+        if params:
+            schema_item, params = params[0], params[1:]
+        else:
+            schema_item = None
+        return ddl(None, schema_item, self, *params, **multiparams)
+
     def _handle_dbapi_exception(self, e, statement, parameters, cursor):
         if getattr(self, '_reentrant_error', False):
             raise exceptions.DBAPIError.instance(None, None, e)
@@ -971,6 +978,7 @@ class Connection(Connectable):
         expression.ClauseElement : execute_clauseelement,
         Compiled : _execute_compiled,
         schema.SchemaItem:_execute_default,
+        schema.DDL: _execute_ddl,
         str.__mro__[-2] : _execute_text
     }
 
index 98e375507e5133c60c972c65df86709f7348d55c..64e9d203d79e1df2e2913339109e113ab22b2089 100644 (file)
@@ -6,17 +6,17 @@
 
 """The schema module provides the building blocks for database metadata.
 
-Each element within this module describes a database entity 
+Each element within this module describes a database entity
 which can be created and dropped, or is otherwise part of such an entity.
 Examples include tables, columns, sequences, and indexes.
 
-All entities are subclasses of [sqlalchemy.schema#SchemaItem], and as 
-defined in this module they are intended to be agnostic of any 
+All entities are subclasses of [sqlalchemy.schema#SchemaItem], and as
+defined in this module they are intended to be agnostic of any
 vendor-specific constructs.
 
 A collection of entities are grouped into a unit called [sqlalchemy.schema#MetaData].
 MetaData serves as a logical grouping of schema elements, and can also
-be associated with an actual database connection such that operations 
+be associated with an actual database connection such that operations
 involving the contained elements can contact the database as needed.
 
 Two of the elements here also build upon their "syntactic" counterparts,
@@ -35,7 +35,7 @@ __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
            'ForeignKeyConstraint', 'PrimaryKeyConstraint', 'CheckConstraint',
            'UniqueConstraint', 'DefaultGenerator', 'Constraint', 'MetaData',
            'ThreadLocalMetaData', 'SchemaVisitor', 'PassiveDefault',
-           'ColumnDefault']
+           'ColumnDefault', 'DDL']
 
 class SchemaItem(object):
     """Base class for items that define a database schema."""
@@ -53,11 +53,11 @@ class SchemaItem(object):
         """Associate with this SchemaItem's parent object."""
 
         raise NotImplementedError()
-    
+
     def get_children(self, **kwargs):
         """used to allow SchemaVisitor access"""
         return []
-        
+
     def __repr__(self):
         return "%s()" % self.__class__.__name__
 
@@ -67,7 +67,7 @@ class SchemaItem(object):
         m = self.metadata
         return m and m.bind or None
     bind = property(bind)
-    
+
     def info(self):
         try:
             return self._info
@@ -75,7 +75,7 @@ class SchemaItem(object):
             self._info = {}
             return self._info
     info = property(info)
-    
+
 
 def _get_table_key(name, schema):
     if schema is None:
@@ -113,6 +113,8 @@ class Table(SchemaItem, expression.TableClause):
 
     __metaclass__ = _TableSingleton
 
+    ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop')
+
     def __init__(self, name, metadata, *args, **kwargs):
         """Construct a Table.
 
@@ -134,7 +136,7 @@ class Table(SchemaItem, expression.TableClause):
 
         \**kwargs
           kwargs include:
-          
+
           schema
             The *schema name* for this table, which is
             required if the table resides in a schema other than the
@@ -150,14 +152,14 @@ class Table(SchemaItem, expression.TableClause):
             if autoload==True, this is an optional Engine or Connection
             instance to be used for the table reflection.  If ``None``,
             the underlying MetaData's bound connectable will be used.
-        
+
           include_columns
-            A list of strings indicating a subset of columns to be 
+            A list of strings indicating a subset of columns to be
             loaded via the ``autoload`` operation; table columns who
             aren't present in this list will not be represented on the resulting
-            ``Table`` object.  Defaults to ``None`` which indicates all 
+            ``Table`` object.  Defaults to ``None`` which indicates all
             columns should be reflected.
-        
+
           info
             Defaults to {}: A space to store application specific data;
             this must be a dictionary.
@@ -198,6 +200,7 @@ class Table(SchemaItem, expression.TableClause):
         self._columns = expression.ColumnCollection()
         self.primary_key = PrimaryKeyConstraint()
         self._foreign_keys = util.OrderedSet()
+        self.ddl_listeners = util.defaultdict(list)
         self.quote = kwargs.pop('quote', False)
         self.quote_schema = kwargs.pop('quote_schema', False)
         if self.schema is not None:
@@ -207,7 +210,7 @@ class Table(SchemaItem, expression.TableClause):
         self.owner = kwargs.pop('owner', None)
         if kwargs.get('info'):
             self._info = kwargs.pop('info')
-        
+
         autoload = kwargs.pop('autoload', False)
         autoload_with = kwargs.pop('autoload_with', None)
         include_columns = kwargs.pop('include_columns', None)
@@ -217,7 +220,7 @@ class Table(SchemaItem, expression.TableClause):
             raise TypeError("Invalid argument(s) for Table: %s" % repr(kwargs.keys()))
 
         self.kwargs = kwargs
-        
+
         self._set_parent(metadata)
         # load column definitions from the database if 'autoload' is defined
         # we do it after the table is in the singleton dictionary to support
@@ -227,15 +230,15 @@ class Table(SchemaItem, expression.TableClause):
                 autoload_with.reflecttable(self, include_columns=include_columns)
             else:
                 _bind_or_error(metadata).reflecttable(self, include_columns=include_columns)
-                
+
         # initialize all the column, etc. objects.  done after
         # reflection to allow user-overrides
         self._init_items(*args)
-    
+
     def key(self):
         return _get_table_key(self.name, self.schema)
     key = property(key)
-    
+
     def _export_columns(self, columns=None):
         # override FromClause's collection initialization logic; Table implements it differently
         pass
@@ -269,6 +272,37 @@ class Table(SchemaItem, expression.TableClause):
 
         constraint._set_parent(self)
 
+    def append_ddl_listener(self, event, listener):
+        """Append a DDL event listener to this ``Table``.
+
+        The ``listener`` callable will be triggered when this ``Table`` is
+        created or dropped, either directly before or after the DDL is issued
+        to the database.  The listener may modify the Table, but may not abort
+        the event itself.
+
+        Arguments are:
+
+        event
+          One of ``Table.ddl_events``; e.g. 'before-create', 'after-create',
+          'before-drop' or 'after-drop'.
+
+        listener
+          A callable, invoked with three positional arguments:
+
+          event
+            The event currently being handled
+          schema_item
+            The ``Table`` object being created or dropped
+          bind
+            The ``Connection`` bueing used for DDL execution.
+
+        Listeners are added to the Table's ``ddl_listeners`` attribute.
+        """
+
+        if event not in self.ddl_events:
+            raise LookupError(event)
+        self.ddl_listeners[event].append(listener)
+
     def _set_parent(self, metadata):
         metadata.tables[_get_table_key(self.name, self.schema)] = self
         self.metadata = metadata
@@ -448,7 +482,7 @@ class Column(SchemaItem, expression._ColumnClause):
             self._info = kwargs.pop('info')
         if kwargs:
             raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys()))
-    
+
     def __str__(self):
         if self.table is not None:
             if self.table.named_with_column:
@@ -461,7 +495,7 @@ class Column(SchemaItem, expression._ColumnClause):
     def bind(self):
         return self.table.bind
     bind = property(bind)
-    
+
     def references(self, column):
         """return true if this column references the given column via foreign key"""
         for fk in self.foreign_keys:
@@ -469,7 +503,7 @@ class Column(SchemaItem, expression._ColumnClause):
                 return True
         else:
             return False
-            
+
     def append_foreign_key(self, fk):
         fk._set_parent(self)
 
@@ -502,7 +536,7 @@ class Column(SchemaItem, expression._ColumnClause):
             table._columns.replace(self)
         else:
             self._pre_existing_column = None
-            
+
         if self.primary_key:
             table.primary_key.replace(self)
         elif self.key in table.primary_key:
@@ -616,13 +650,13 @@ class ForeignKey(SchemaItem):
         """Return True if the given table is referenced by this ``ForeignKey``."""
 
         return table.corresponding_column(self.column) is not None
-    
+
     def get_referent(self, table):
         """return the column in the given table referenced by this ``ForeignKey``, or
         None if this ``ForeignKey`` does not reference the given table.
         """
         return table.corresponding_column(self.column)
-        
+
     def column(self):
         # ForeignKey inits its remote column as late as possible, so tables can
         # be defined without dependencies
@@ -661,7 +695,7 @@ class ForeignKey(SchemaItem):
                     raise exceptions.ArgumentError("Could not create ForeignKey '%s' on table '%s': table '%s' has no column named '%s'" % (self._colspec, parenttable.name, table.name, str(e)))
             else:
                 self._column = self._colspec
-                
+
         # propigate TypeEngine to parent if it didnt have one
         if isinstance(self.parent.type, types.NullType):
             self.parent.type = self._column.type
@@ -671,14 +705,14 @@ class ForeignKey(SchemaItem):
 
     def _set_parent(self, column):
         self.parent = column
-        
+
         if self.parent._pre_existing_column is not None:
             # remove existing FK which matches us
             for fk in self.parent._pre_existing_column.foreign_keys:
                 if fk._colspec == self._colspec:
                     self.parent.table.foreign_keys.remove(fk)
                     self.parent.table.constraints.remove(fk.constraint)
-            
+
         if self.constraint is None and isinstance(self.parent.table, Table):
             self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete)
             self.parent.table.append_constraint(self.constraint)
@@ -904,12 +938,12 @@ class PrimaryKeyConstraint(Constraint):
         table.primary_key = self
         for c in self.__colnames:
             self.add(table.c[c])
-    
+
     def add(self, col):
         self.columns.add(col)
         col.primary_key=True
     append_column = add
-    
+
     def replace(self, col):
         self.columns.replace(col)
 
@@ -969,7 +1003,6 @@ class Index(SchemaItem):
         self.columns = []
         self.table = None
         self.unique = kwargs.pop('unique', False)
-
         self.kwargs = kwargs
 
         self._init_items(*columns)
@@ -1022,7 +1055,7 @@ class Index(SchemaItem):
 
 class MetaData(SchemaItem):
     """A collection of Tables and their associated schema constructs.
-    
+
     Holds a collection of Tables and an optional binding to an
     ``Engine`` or ``Connection``.  If bound, the
     [sqlalchemy.schema#Table] objects in the collection and their
@@ -1044,10 +1077,12 @@ class MetaData(SchemaItem):
     """
 
     __visit_name__ = 'metadata'
-    
+
+    ddl_events = ('before-create', 'after-create', 'before-drop', 'after-drop')
+
     def __init__(self, bind=None, reflect=False):
         """Create a new MetaData object.
-            
+
         bind
           An Engine or Connection to bind to.  May also be a string or
           URL instance, these are passed to create_engine() and this
@@ -1059,39 +1094,40 @@ class MetaData(SchemaItem):
           set.  For finer control over loaded tables, use the ``reflect``
           method of ``MetaData``.
 
-        """        
+        """
 
         self.tables = {}
         self.bind = bind
         self.metadata = self
+        self.ddl_listeners = util.defaultdict(list)
         if reflect:
             if not bind:
                 raise exceptions.ArgumentError(
                     "A bind must be supplied in conjunction with reflect=True")
             self.reflect()
-        
+
     def __repr__(self):
         return 'MetaData(%r)' % self.bind
 
     def __contains__(self, key):
         return key in self.tables
-        
+
     def __getstate__(self):
         return {'tables': self.tables}
 
     def __setstate__(self, state):
         self.tables = state['tables']
         self._bind = None
-        
+
     def is_bound(self):
         """True if this MetaData is bound to an Engine or Connection."""
 
         return self._bind is not None
-        
+
     # @deprecated
     def connect(self, bind, **kwargs):
         """Bind this MetaData to an Engine.
-            
+
         Use ``metadata.bind = <engine>`` or ``metadata.bind = <url>``.
 
         bind
@@ -1100,7 +1136,7 @@ class MetaData(SchemaItem):
           with ``\**kwargs`` to produce the engine which to connect to.
           Otherwise connects directly to the given ``Engine``.
         """
-        
+
         global URL
         if URL is None:
             from sqlalchemy.engine.url import URL
@@ -1118,9 +1154,9 @@ class MetaData(SchemaItem):
         automatically create a basic ``Engine`` for this bind
         with ``create_engine()``.
         """
-        
+
         return self._bind
-        
+
     def _bind_to(self, bind):
         """Bind this MetaData to an Engine, Connection, string or URL."""
 
@@ -1134,14 +1170,14 @@ class MetaData(SchemaItem):
         else:
             self._bind = bind
     bind = property(bind, _bind_to)
-    
+
     def clear(self):
         self.tables.clear()
 
     def remove(self, table):
-        # TODO: scan all other tables and remove FK _column 
+        # TODO: scan all other tables and remove FK _column
         del self.tables[table.key]
-        
+
     def table_iterator(self, reverse=True, tables=None):
         from sqlalchemy.sql.util import sort_tables
         if tables is None:
@@ -1169,7 +1205,7 @@ class MetaData(SchemaItem):
         only
           Optional.  Load only a sub-set of available named tables.  May
           be specified as a sequence of names or a callable.
-        
+
           If a sequence of names is provided, only those tables will be
           reflected.  An error is raised if a table is requested but not
           available.  Named tables already present in this ``MetaData`` are
@@ -1213,6 +1249,44 @@ class MetaData(SchemaItem):
         for name in load:
             Table(name, self, **reflect_opts)
 
+    def append_ddl_listener(self, event, listener):
+        """Append a DDL event listener to this ``MetaData``.
+
+        The ``listener`` callable will be triggered when this ``MetaData`` is
+        involved in DDL creates or drops, and will be invoked either before
+        all Table-related actions or after.
+
+        Arguments are:
+
+        event
+          One of ``MetaData.ddl_events``; 'before-create', 'after-create',
+          'before-drop' or 'after-drop'.
+        listener
+          A callable, invoked with three positional arguments:
+
+          event
+            The event currently being handled
+          schema_item
+            The ``MetaData`` object being operated upon
+          bind
+            The ``Connection`` bueing used for DDL execution.
+
+        Listeners are added to the MetaData's ``ddl_listeners`` attribute.
+
+        Note: MetaData listeners are invoked even when ``Tables`` are created
+        in isolation.  This may change in a future release. I.e.::
+
+          # triggers all MetaData and Table listeners:
+          metadata.create_all()
+
+          # triggers MetaData listeners too:
+          some.table.create()
+        """
+
+        if event not in self.ddl_events:
+            raise LookupError(event)
+        self.ddl_listeners[event].append(listener)
+
     def create_all(self, bind=None, tables=None, checkfirst=True):
         """Create all tables stored in this metadata.
 
@@ -1230,7 +1304,11 @@ class MetaData(SchemaItem):
 
         if bind is None:
             bind = _bind_or_error(self)
+        for listener in self.ddl_listeners['before-create']:
+            listener('before-create', self, bind)
         bind.create(self, checkfirst=checkfirst, tables=tables)
+        for listener in self.ddl_listeners['after-create']:
+            listener('after-create', self, bind)
 
     def drop_all(self, bind=None, tables=None, checkfirst=True):
         """Drop all tables stored in this metadata.
@@ -1241,7 +1319,7 @@ class MetaData(SchemaItem):
         bind
           A ``Connectable`` used to access the database; if None, uses
           the existing bind on this ``MetaData``, if any.
-          
+
         tables
           Optional list of ``Table`` objects, which is a subset of the
           total tables in the ``MetaData`` (others are ignored).
@@ -1249,8 +1327,12 @@ class MetaData(SchemaItem):
 
         if bind is None:
             bind = _bind_or_error(self)
+        for listener in self.ddl_listeners['before-drop']:
+            listener('before-drop', self, bind)
         bind.drop(self, checkfirst=checkfirst, tables=tables)
-    
+        for listener in self.ddl_listeners['after-drop']:
+            listener('after-drop', self, bind)
+
 class ThreadLocalMetaData(MetaData):
     """A MetaData variant that presents a different ``bind`` in every thread.
 
@@ -1271,24 +1353,24 @@ class ThreadLocalMetaData(MetaData):
 
     def __init__(self):
         """Construct a ThreadLocalMetaData."""
-    
+
         self.context = util.ThreadLocal()
         self.__engines = {}
         super(ThreadLocalMetaData, self).__init__()
 
     # @deprecated
-    def connect(self, bind, **kwargs): 
+    def connect(self, bind, **kwargs):
         """Bind to an Engine in the caller's thread.
-            
+
         Use ``metadata.bind=<engine>`` or ``metadata.bind=<url>``.
-        
+
         bind
           A string, ``URL``, ``Engine`` or ``Connection`` instance.  If
           a string or ``URL``, will be passed to ``create_engine()`` along
           with ``\**kwargs`` to produce the engine which to connect to.
           Otherwise connects directly to the given ``Engine``.
         """
-       
+
         global URL
         if URL is None:
             from sqlalchemy.engine.url import URL
@@ -1308,7 +1390,7 @@ class ThreadLocalMetaData(MetaData):
         This property may be assigned an Engine or Connection,
         or assigned a string or URL to automatically create a
         basic Engine for this bind with ``create_engine()``."""
-        
+
         return getattr(self.context, '_engine', None)
 
     def _bind_to(self, bind):
@@ -1351,6 +1433,198 @@ class SchemaVisitor(visitors.ClauseVisitor):
 
     __traverse_options__ = {'schema_visitor':True}
 
+
+class DDL(object):
+    """A literal DDL statement.
+
+    Specifies literal SQL DDL to be executed by the database.  DDL objects can
+    be attached to ``Tables`` or ``MetaData`` instances, conditionally
+    executing SQL as part of the DDL lifecycle of those schema items.  Basic
+    templating support allows a single DDL instance to handle repetitive tasks
+    for multiple tables.
+
+    Examples::
+
+      tbl = Table('users', metadata, Column('uid', Integer)) # ...
+      DDL('DROP TRIGGER users_trigger').execute_at('before-create', tbl)
+
+      spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE', on='somedb')
+      spow.execute_at('after-create', tbl)
+
+      drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE')
+      connection.execute(drop_spow)
+    """
+
+    def __init__(self, statement, on=None, context=None):
+        """Create a DDL statement.
+
+        statement
+          A string or unicode string to be executed.  Statements will
+          be processed with Python's string formatting operator.  See the
+          ``context`` argument and the ``execute_at`` method.
+
+          A literal '%' in a statement must be escaped as '%%'.
+
+          Bind parameters are not available in DDL statements.
+
+        on
+          Optional filtering criteria.  May be a string or a callable
+          predicate.  If a string, it will be compared to the name of the
+          executing database dialect::
+
+            DDL('something', on='postgres')
+
+          If a callable, it will be invoked with three positional arguments:
+
+            event
+              The name of the event that has triggered this DDL, such
+              as 'after-create'  Will be None if the DDL is executed
+              explicitly.
+
+            schema_item
+              A SchemaItem instance, such as ``Table`` or ``MetaData``. May
+              be None if the DDL is executed explicitly.
+
+            connection
+              The ``Connection`` being used for DDL execution
+
+          If the callable returns a true value, the DDL statement will
+          be executed.
+
+        context
+          Optional dictionary, defaults to None.  These values will be
+          available for use in string substitutions on the DDL statement.
+        """
+
+        if not isinstance(statement, basestring):
+            raise exceptions.ArgumentError(
+                "Expected a string or unicode SQL statement, got '%r'" %
+                statement)
+        if (on is not None and
+            (not isinstance(on, basestring) and not callable(on))):
+            raise exceptions.ArgumentError(
+                "Expected the name of a database dialect or a callable for "
+                "'on' criteria, got type '%s'." % type(on).__name__)
+
+        self.statement = statement
+        self.on = on
+        self.context = context or {}
+
+    def execute(self, bind, schema_item=None):
+        """Execute this DDL immediately.
+
+        Executes the DDL statement in isolation using the supplied
+        ``Connectable``.  If the DDL has a conditional ``on`` criteria, it
+        will be invoked with None as the event.
+
+        bind
+          An Engine or Connection
+
+        schema_item
+          Optional, defaults to None.  Will be passed to the ``on`` callable
+          criteria, if any, and may provide string expansion data for the
+          statement. See ``execute_at`` for more information.
+        """
+
+        # no bind params are supported
+        if self._should_execute(None, schema_item, bind):
+            executable = expression.text(self._expand(schema_item, bind))
+            return bind.execute(executable)
+        else:
+            bind.engine.logger.info("DDL execution skipped, criteria not met.")
+
+    def execute_at(self, event, schema_item):
+        """Link execution of this DDL to the DDL lifecycle of a SchemaItem.
+
+        Links this ``DDL`` to a ``Table`` or ``MetaData`` instance, executing
+        it when that schema item is created or dropped.
+
+        event
+          One of the events defined in the schema item's ``.ddl_events``;
+          e.g. 'before-create', 'after-create', 'before-drop' or 'after-drop'
+
+        schema_item
+          A Table or MetaData instance
+
+        When operating on Table events, the following additional ``statement``
+        string substitions are available::
+
+            %(table)s  - the Table name, with any required quoting applied
+            %(schema)s - the schema name, with any required quoting applied
+            %(fullname)s - the Table name including schema, quoted if needed
+
+        The DDL's ``context``, if any, will be combined with the standard
+        substutions noted above.  Keys present in the context will override
+        the standard substitutions.
+
+        A DDL instance can be linked to any number of schema items. The
+        statement subsitution support allows for DDL instances to be used in a
+        template fashion.
+
+        ``execute_at`` builds on the ``append_ddl_listener`` interface of
+        MetaDta and Table objects.
+
+        Caveat: Creating or dropping a Table in isolation will also trigger
+        any DDL set to ``execute_at`` that Table's MetaData.  This may change
+        in a future release.
+        """
+
+        if not hasattr(schema_item, 'ddl_listeners'):
+            raise exceptions.ArgumentError(
+                "%s does not support DDL events" % type(schema_item).__name__)
+        if event not in schema_item.ddl_events:
+            raise exceptions.ArgumentError(
+                "Unknown event, expected one of (%s), got '%r'" %
+                (', '.join(schema_item.ddl_events), event))
+        schema_item.ddl_listeners[event].append(self)
+        return self
+
+    def __call__(self, event, schema_item, bind):
+        """Execute the DDL as a ddl_listener."""
+
+        if self._should_execute(event, schema_item, bind):
+            statement = expression.text(self._expand(schema_item, bind))
+            return bind.execute(statement)
+
+    def _expand(self, schema_item, bind):
+        return self.statement % self._prepare_context(schema_item, bind)
+
+    def _should_execute(self, event, schema_item, bind):
+        if self.on is None:
+            return True
+        elif isinstance(self.on, basestring):
+            return self.on == bind.engine.name
+        else:
+            return self.on(event, schema_item, bind)
+
+    def _prepare_context(self, schema_item, bind):
+        # table events can substitute table and schema name
+        if isinstance(schema_item, Table):
+            context = self.context.copy()
+
+            preparer = bind.dialect.identifier_preparer
+            path = preparer.format_table_seq(schema_item)
+            if len(path) == 1:
+                table, schema = path[0], ''
+            else:
+                table, schema = path[-1], path[0]
+
+            context.setdefault('table', table)
+            context.setdefault('schema', schema)
+            context.setdefault('fullname', preparer.format_table(schema_item))
+            return context
+        else:
+            return self.context
+
+    def __repr__(self):
+        return '<%s@%s; %s>' % (
+            type(self).__name__, id(self),
+            ', '.join([repr(self.statement)] +
+                      ['%s=%r' % (key, getattr(self, key))
+                       for key in ('on', 'context')
+                       if getattr(self, key)]))
+
+
 def _bind_or_error(schemaitem):
     bind = schemaitem.bind
     if not bind:
index cdb6804317c68fa2be2ec60fd205216491166ca8..4e73221c121ba4975838084029d3f83fa11e0ecf 100644 (file)
@@ -775,6 +775,9 @@ class SchemaGenerator(DDLBase):
                 self.add_foreignkey(alterable)
 
     def visit_table(self, table):
+        for listener in table.ddl_listeners['before-create']:
+            listener('before-create', table, self.connection)
+
         for column in table.columns:
             if column.default is not None:
                 self.traverse_single(column.default)
@@ -803,10 +806,14 @@ class SchemaGenerator(DDLBase):
 
         self.append("\n)%s\n\n" % self.post_create_table(table))
         self.execute()
+
         if hasattr(table, 'indexes'):
             for index in table.indexes:
                 self.traverse_single(index)
 
+        for listener in table.ddl_listeners['after-create']:
+            listener('after-create', table, self.connection)
+
     def post_create_table(self, table):
         return ''
 
@@ -892,6 +899,7 @@ class SchemaGenerator(DDLBase):
                        string.join([preparer.quote(c, c.name) for c in index.columns], ', ')))
         self.execute()
 
+
 class SchemaDropper(DDLBase):
     def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
         super(SchemaDropper, self).__init__(connection, **kwargs)
@@ -919,6 +927,9 @@ class SchemaDropper(DDLBase):
         self.execute()
 
     def visit_table(self, table):
+        for listener in table.ddl_listeners['before-drop']:
+            listener('before-drop', table, self.connection)
+
         for column in table.columns:
             if column.default is not None:
                 self.traverse_single(column.default)
@@ -926,6 +937,10 @@ class SchemaDropper(DDLBase):
         self.append("\nDROP TABLE " + self.preparer.format_table(table))
         self.execute()
 
+        for listener in table.ddl_listeners['after-drop']:
+            listener('after-drop', table, self.connection)
+
+
 class IdentifierPreparer(object):
     """Handle quoting and case-folding of identifiers based on options."""
 
index c0d0c7eed7b58c363103e6edf0f4b8078eeca752..af14119886a895df0b97086e72c163a550601413 100644 (file)
@@ -125,6 +125,44 @@ else:
                 self[key] = value = self.creator(key)
                 return value
 
+try:
+    from collections import defaultdict
+except ImportError:
+    class defaultdict(dict):
+        def __init__(self, default_factory=None, *a, **kw):
+            if (default_factory is not None and
+                not hasattr(default_factory, '__call__')):
+                raise TypeError('first argument must be callable')
+            dict.__init__(self, *a, **kw)
+            self.default_factory = default_factory
+        def __getitem__(self, key):
+            try:
+                return dict.__getitem__(self, key)
+            except KeyError:
+                return self.__missing__(key)
+        def __missing__(self, key):
+            if self.default_factory is None:
+                raise KeyError(key)
+            self[key] = value = self.default_factory()
+            return value
+        def __reduce__(self):
+            if self.default_factory is None:
+                args = tuple()
+            else:
+                args = self.default_factory,
+            return type(self), args, None, None, self.iteritems()
+        def copy(self):
+            return self.__copy__()
+        def __copy__(self):
+            return type(self)(self.default_factory, self)
+        def __deepcopy__(self, memo):
+            import copy
+            return type(self)(self.default_factory,
+                              copy.deepcopy(self.items()))
+        def __repr__(self):
+            return 'defaultdict(%s, %s)' % (self.default_factory,
+                                            dict.__repr__(self))
+
 def to_list(x, default=None):
     if x is None:
         return default
index 4e76292989432d4fc6048c2c618540da1cc19dc3..75167d5d6f13996e015cb0dddd2081995623c4c4 100644 (file)
@@ -15,6 +15,7 @@ def suite():
 
         # schema/tables
         'engine.reflection',
+        'engine.ddlevents',
 
         )
     alltests = unittest.TestSuite()
diff --git a/test/engine/ddlevents.py b/test/engine/ddlevents.py
new file mode 100644 (file)
index 0000000..e902ec7
--- /dev/null
@@ -0,0 +1,346 @@
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import *
+from sqlalchemy.schema import DDL
+import sqlalchemy
+from testlib import *
+
+
+class DDLEventTest(PersistTest):
+    class Canary(object):
+        def __init__(self, schema_item, bind):
+            self.state = None
+            self.schema_item = schema_item
+            self.bind = bind
+
+        def before_create(self, action, schema_item, bind):
+            assert self.state is None
+            assert schema_item is self.schema_item
+            assert bind is self.bind
+            self.state = action
+
+        def after_create(self, action, schema_item, bind):
+            assert self.state in ('before-create', 'skipped')
+            assert schema_item is self.schema_item
+            assert bind is self.bind
+            self.state = action
+
+        def before_drop(self, action, schema_item, bind):
+            assert self.state is None
+            assert schema_item is self.schema_item
+            assert bind is self.bind
+            self.state = action
+
+        def after_drop(self, action, schema_item, bind):
+            assert self.state in ('before-drop', 'skipped')
+            assert schema_item is self.schema_item
+            assert bind is self.bind
+            self.state = action
+
+    def mock_engine(self):
+        buffer = []
+        def executor(sql, *a, **kw):
+            buffer.append(sql)
+        engine = create_engine(testing.db.name + '://',
+                               strategy='mock', executor=executor)
+        assert not hasattr(engine, 'mock')
+        engine.mock = buffer
+        return engine
+
+    def setUp(self):
+        self.bind = self.mock_engine()
+        self.metadata = MetaData()
+        self.table = Table('t', self.metadata, Column('id', Integer))
+
+    def test_table_create_before(self):
+        table, bind = self.table, self.bind
+        canary = self.Canary(table, bind)
+        table.ddl_listeners['before-create'].append(canary.before_create)
+
+        table.create(bind)
+        assert canary.state == 'before-create'
+        table.drop(bind)
+        assert canary.state == 'before-create'
+
+    def test_table_create_after(self):
+        table, bind = self.table, self.bind
+        canary = self.Canary(table, bind)
+        table.ddl_listeners['after-create'].append(canary.after_create)
+
+        canary.state = 'skipped'
+        table.create(bind)
+        assert canary.state == 'after-create'
+        table.drop(bind)
+        assert canary.state == 'after-create'
+
+    def test_table_create_both(self):
+        table, bind = self.table, self.bind
+        canary = self.Canary(table, bind)
+        table.ddl_listeners['before-create'].append(canary.before_create)
+        table.ddl_listeners['after-create'].append(canary.after_create)
+
+        table.create(bind)
+        assert canary.state == 'after-create'
+        table.drop(bind)
+        assert canary.state == 'after-create'
+
+    def test_table_drop_before(self):
+        table, bind = self.table, self.bind
+        canary = self.Canary(table, bind)
+        table.ddl_listeners['before-drop'].append(canary.before_drop)
+
+        table.create(bind)
+        assert canary.state is None
+        table.drop(bind)
+        assert canary.state == 'before-drop'
+
+    def test_table_drop_after(self):
+        table, bind = self.table, self.bind
+        canary = self.Canary(table, bind)
+        table.ddl_listeners['after-drop'].append(canary.after_drop)
+
+        table.create(bind)
+        assert canary.state is None
+        canary.state = 'skipped'
+        table.drop(bind)
+        assert canary.state == 'after-drop'
+
+    def test_table_drop_both(self):
+        table, bind = self.table, self.bind
+        canary = self.Canary(table, bind)
+        table.ddl_listeners['before-drop'].append(canary.before_drop)
+        table.ddl_listeners['after-drop'].append(canary.after_drop)
+
+        table.create(bind)
+        assert canary.state is None
+        table.drop(bind)
+        assert canary.state == 'after-drop'
+
+    def test_table_all(self):
+        table, bind = self.table, self.bind
+        canary = self.Canary(table, bind)
+        table.ddl_listeners['before-create'].append(canary.before_create)
+        table.ddl_listeners['after-create'].append(canary.after_create)
+        table.ddl_listeners['before-drop'].append(canary.before_drop)
+        table.ddl_listeners['after-drop'].append(canary.after_drop)
+
+        assert canary.state is None
+        table.create(bind)
+        assert canary.state == 'after-create'
+        canary.state = None
+        table.drop(bind)
+        assert canary.state == 'after-drop'
+
+    def test_table_create_before(self):
+        metadata, bind = self.metadata, self.bind
+        canary = self.Canary(metadata, bind)
+        metadata.ddl_listeners['before-create'].append(canary.before_create)
+
+        metadata.create_all(bind)
+        assert canary.state == 'before-create'
+        metadata.drop_all(bind)
+        assert canary.state == 'before-create'
+
+    def test_metadata_create_after(self):
+        metadata, bind = self.metadata, self.bind
+        canary = self.Canary(metadata, bind)
+        metadata.ddl_listeners['after-create'].append(canary.after_create)
+
+        canary.state = 'skipped'
+        metadata.create_all(bind)
+        assert canary.state == 'after-create'
+        metadata.drop_all(bind)
+        assert canary.state == 'after-create'
+
+    def test_metadata_create_both(self):
+        metadata, bind = self.metadata, self.bind
+        canary = self.Canary(metadata, bind)
+        metadata.ddl_listeners['before-create'].append(canary.before_create)
+        metadata.ddl_listeners['after-create'].append(canary.after_create)
+
+        metadata.create_all(bind)
+        assert canary.state == 'after-create'
+        metadata.drop_all(bind)
+        assert canary.state == 'after-create'
+
+    @testing.future
+    def test_metadata_table_isolation(self):
+        metadata, table, bind = self.metadata, self.table, self.bind
+
+        table_canary = self.Canary(table, bind)
+        table.ddl_listeners['before-create'].append(table_canary.before_create)
+
+        metadata_canary = self.Canary(metadata, bind)
+        metadata.ddl_listeners['before-create'].append(metadata_canary.before_create)
+
+        # currently, table.create() routes through the same execution
+        # path that metadata.create_all() does
+        self.table.create(self.bind)
+        assert metadata_canary.state == None
+
+    def test_append_listener(self):
+        metadata, table, bind = self.metadata, self.table, self.bind
+
+        fn = lambda *a: None
+
+        table.append_ddl_listener('before-create', fn)
+        self.assertRaises(LookupError, table.append_ddl_listener, 'blah', fn)
+
+        metadata.append_ddl_listener('before-create', fn)
+        self.assertRaises(LookupError, metadata.append_ddl_listener, 'blah', fn)
+
+
+class DDLExecutionTest(PersistTest):
+    def mock_engine(self):
+        buffer = []
+        def executor(sql, *a, **kw):
+            buffer.append(sql)
+        engine = create_engine(testing.db.name + '://',
+                               strategy='mock', executor=executor)
+        assert not hasattr(engine, 'mock')
+        engine.mock = buffer
+        return engine
+
+    def setUp(self):
+        self.engine = self.mock_engine()
+        self.metadata = MetaData(self.engine)
+        self.users = Table('users', self.metadata,
+                           Column('user_id', Integer, primary_key=True),
+                           Column('user_name', String(40)),
+                           )
+
+    def test_table_standalone(self):
+        users, engine = self.users, self.engine
+        DDL('mxyzptlk').execute_at('before-create', users)
+        DDL('klptzyxm').execute_at('after-create', users)
+        DDL('xyzzy').execute_at('before-drop', users)
+        DDL('fnord').execute_at('after-drop', users)
+
+        users.create()
+        strings = [str(x) for x in engine.mock]
+        assert 'mxyzptlk' in strings
+        assert 'klptzyxm' in strings
+        assert 'xyzzy' not in strings
+        assert 'fnord' not in strings
+        del engine.mock[:]
+        users.drop()
+        strings = [str(x) for x in engine.mock]
+        assert 'mxyzptlk' not in strings
+        assert 'klptzyxm' not in strings
+        assert 'xyzzy' in strings
+        assert 'fnord' in strings
+
+    def test_table_by_metadata(self):
+        metadata, users, engine = self.metadata, self.users, self.engine
+        DDL('mxyzptlk').execute_at('before-create', users)
+        DDL('klptzyxm').execute_at('after-create', users)
+        DDL('xyzzy').execute_at('before-drop', users)
+        DDL('fnord').execute_at('after-drop', users)
+
+        metadata.create_all()
+        strings = [str(x) for x in engine.mock]
+        assert 'mxyzptlk' in strings
+        assert 'klptzyxm' in strings
+        assert 'xyzzy' not in strings
+        assert 'fnord' not in strings
+        del engine.mock[:]
+        metadata.drop_all()
+        strings = [str(x) for x in engine.mock]
+        assert 'mxyzptlk' not in strings
+        assert 'klptzyxm' not in strings
+        assert 'xyzzy' in strings
+        assert 'fnord' in strings
+
+    def test_metadata(self):
+        metadata, engine = self.metadata, self.engine
+        DDL('mxyzptlk').execute_at('before-create', metadata)
+        DDL('klptzyxm').execute_at('after-create', metadata)
+        DDL('xyzzy').execute_at('before-drop', metadata)
+        DDL('fnord').execute_at('after-drop', metadata)
+
+        metadata.create_all()
+        strings = [str(x) for x in engine.mock]
+        assert 'mxyzptlk' in strings
+        assert 'klptzyxm' in strings
+        assert 'xyzzy' not in strings
+        assert 'fnord' not in strings
+        del engine.mock[:]
+        metadata.drop_all()
+        strings = [str(x) for x in engine.mock]
+        assert 'mxyzptlk' not in strings
+        assert 'klptzyxm' not in strings
+        assert 'xyzzy' in strings
+        assert 'fnord' in strings
+
+    def test_ddl_execute(self):
+        engine = create_engine('sqlite:///')
+        cx = engine.connect()
+        table = self.users
+        ddl = DDL('SELECT 1')
+
+        for py in ('engine.execute(ddl)',
+                   'engine.execute(ddl, table)',
+                   'cx.execute(ddl)',
+                   'cx.execute(ddl, table)',
+                   'ddl.execute(engine)',
+                   'ddl.execute(engine, table)',
+                   'ddl.execute(cx)',
+                   'ddl.execute(cx, table)'):
+            r = eval(py)
+            assert list(r) == [(1,)], py
+
+class DDLTest(PersistTest):
+    def mock_engine(self):
+        executor = lambda *a, **kw: None
+        engine = create_engine(testing.db.name + '://',
+                               strategy='mock', executor=executor)
+        engine.dialect.identifier_preparer = \
+           sqlalchemy.sql.compiler.IdentifierPreparer(engine.dialect)
+        return engine
+
+    def test_tokens(self):
+        m = MetaData()
+        bind = self.mock_engine()
+        sane_alone = Table('t', m, Column('id', Integer))
+        sane_schema = Table('t', m, Column('id', Integer), schema='s')
+        insane_alone = Table('t t', m, Column('id', Integer))
+        insane_schema = Table('t t', m, Column('id', Integer), schema='s s')
+
+        ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
+
+        self.assertEquals(ddl._expand(sane_alone, bind), '-t-t')
+        self.assertEquals(ddl._expand(sane_schema, bind), '"s"-t-s.t')
+        self.assertEquals(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
+        self.assertEquals(ddl._expand(insane_schema, bind),
+                          '"s s"-"t t"-"s s"."t t"')
+
+        # overrides are used piece-meal and verbatim.
+        ddl = DDL('%(schema)s-%(table)s-%(fullname)s-%(bonus)s',
+                  context={'schema':'S S', 'table': 'T T', 'bonus': 'b'})
+        self.assertEquals(ddl._expand(sane_alone, bind), 'S S-T T-t-b')
+        self.assertEquals(ddl._expand(sane_schema, bind), 'S S-T T-s.t-b')
+        self.assertEquals(ddl._expand(insane_alone, bind), 'S S-T T-"t t"-b')
+        self.assertEquals(ddl._expand(insane_schema, bind),
+                          'S S-T T-"s s"."t t"-b')
+    def test_filter(self):
+        cx = self.mock_engine()
+        cx.name = 'mock'
+
+        tbl = Table('t', MetaData(), Column('id', Integer))
+
+        assert DDL('')._should_execute('x', tbl, cx)
+        assert DDL('', on='mock')._should_execute('x', tbl, cx)
+        assert not DDL('', on='bogus')._should_execute('x', tbl, cx)
+        assert DDL('', on=lambda x,y,z: True)._should_execute('x', tbl, cx)
+        assert(DDL('', on=lambda x,y,z: z.engine.name != 'bogus').
+               _should_execute('x', tbl, cx))
+
+    def test_repr(self):
+        assert repr(DDL('s'))
+        assert repr(DDL('s', on='engine'))
+        assert repr(DDL('s', on=lambda x: 1))
+        assert repr(DDL('s', context={'a':1}))
+        assert repr(DDL('s', on='engine', context={'a':1}))
+
+
+if __name__ == "__main__":
+    testenv.main()