]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- move comparison of types, server default to the context.
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 28 Nov 2011 18:49:58 +0000 (13:49 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 28 Nov 2011 18:49:58 +0000 (13:49 -0500)
PG context in particular does some tricks to help these.
- But since type/default comparison is still loaded with surprises,
particularly the MySQL/MSSQL TINYINT/BIT-> boolean thing which we
can work around but haven't yet, turn both off by default.  They
aren't super useful compared to the huge number of wrong results
they can currently emit.
- Also add a plugin system for type/server default comparison.
- everything works but we're coding way ahead of tests at this
point

alembic/autogenerate.py
alembic/context.py
alembic/ddl/impl.py
alembic/ddl/postgresql.py
docs/build/tutorial.rst
tests/test_autogenerate.py

index 62462b25b387dc46ff9074f00125c28be0cfa4a4..b6811fabb3f9bde0c438657726e347e932e5d47d 100644 (file)
@@ -1,10 +1,10 @@
 """Provide the 'autogenerate' feature which can produce migration operations
 automatically."""
 
-from alembic.context import _context_opts, get_bind
+from alembic.context import _context_opts, get_bind, get_context
 from alembic import util
 from sqlalchemy.engine.reflection import Inspector
-from sqlalchemy import types as sqltypes, schema
+from sqlalchemy import schema, types as sqltypes
 import re
 
 import logging
@@ -22,16 +22,19 @@ def produce_migration_diffs(template_args, imports):
                 "a MetaData object to the context.")
     connection = get_bind()
     diffs = []
-    _produce_net_changes(connection, metadata, diffs)
-    _set_upgrade(template_args, _indent(_produce_upgrade_commands(diffs, imports)))
-    _set_downgrade(template_args, _indent(_produce_downgrade_commands(diffs, imports)))
+    autogen_context = {
+        'imports':imports,
+        'connection':connection,
+        'dialect':connection.dialect,
+        'context':get_context()
+    }
+    _produce_net_changes(connection, metadata, diffs, autogen_context)
+    template_args[_context_opts['upgrade_token']] = \
+            _indent(_produce_upgrade_commands(diffs, autogen_context))
+    template_args[_context_opts['downgrade_token']] = \
+            _indent(_produce_downgrade_commands(diffs, autogen_context))
     template_args['imports'] = "\n".join(sorted(imports))
 
-def _set_upgrade(template_args, text):
-    template_args[_context_opts['upgrade_token']] = text
-
-def _set_downgrade(template_args, text):
-    template_args[_context_opts['downgrade_token']] = text
 
 def _indent(text):
     text = "### commands auto generated by Alembic - please adjust! ###\n" + text
@@ -42,7 +45,7 @@ def _indent(text):
 ###################################################
 # walk structures
 
-def _produce_net_changes(connection, metadata, diffs):
+def _produce_net_changes(connection, metadata, diffs, autogen_context):
     inspector = Inspector.from_engine(connection)
     # TODO: not hardcode alembic_version here ?
     conn_table_names = set(inspector.get_table_names()).\
@@ -76,7 +79,7 @@ def _produce_net_changes(connection, metadata, diffs):
         _compare_columns(tname, 
                 conn_column_info[tname], 
                 metadata.tables[tname],
-                diffs)
+                diffs, autogen_context)
 
     # TODO: 
     # index add/drop
@@ -86,7 +89,7 @@ def _produce_net_changes(connection, metadata, diffs):
 ###################################################
 # element comparison
 
-def _compare_columns(tname, conn_table, metadata_table, diffs):
+def _compare_columns(tname, conn_table, metadata_table, diffs, autogen_context):
     metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c)
     conn_col_names = set(conn_table)
     metadata_col_names = set(metadata_cols_by_name)
@@ -114,24 +117,25 @@ def _compare_columns(tname, conn_table, metadata_table, diffs):
         col_diff = []
         _compare_type(tname, colname,
             conn_col,
-            metadata_col.type,
-            col_diff
+            metadata_col,
+            col_diff, autogen_context
         )
         _compare_nullable(tname, colname,
             conn_col,
             metadata_col.nullable,
-            col_diff
+            col_diff, autogen_context
         )
         _compare_server_default(tname, colname,
             conn_col,
-            metadata_col.server_default,
-            col_diff
+            metadata_col,
+            col_diff, autogen_context
         )
         if col_diff:
             diffs.append(col_diff)
 
 def _compare_nullable(tname, cname, conn_col, 
-                            metadata_col_nullable, diffs):
+                            metadata_col_nullable, diffs, 
+                            autogen_context):
     conn_col_nullable = conn_col['nullable']
     if conn_col_nullable is not metadata_col_nullable:
         diffs.append(
@@ -149,16 +153,23 @@ def _compare_nullable(tname, cname, conn_col,
             cname
         )
 
-def _compare_type(tname, cname, conn_col, metadata_type, diffs):
+def _compare_type(tname, cname, conn_col, 
+                            metadata_col, diffs, 
+                            autogen_context):
+
     conn_type = conn_col['type']
-    if conn_type._compare_type_affinity(metadata_type):
-        comparator = _type_comparators.get(conn_type._type_affinity, None)
+    metadata_type = metadata_col.type
+    if conn_type._type_affinity is sqltypes.NullType:
+        log.info("Couldn't determine database type for column '%s.%s'" % (tname, cname))
+        return
+    if metadata_type._type_affinity is sqltypes.NullType:
+        log.info("Column '%s.%s' has no type within the model; can't compare" % (tname, cname))
+        return
 
-        isdiff = comparator and comparator(metadata_type, conn_type)
-    else:
-        isdiff = True
+    isdiff = autogen_context['context'].compare_type(conn_col, metadata_col)
 
     if isdiff:
+
         diffs.append(
             ("modify_type", tname, cname, 
                     {
@@ -172,10 +183,20 @@ def _compare_type(tname, cname, conn_col, metadata_type, diffs):
             conn_type, metadata_type, tname, cname
         )
 
-def _compare_server_default(tname, cname, conn_col, metadata_default, diffs):
+def _compare_server_default(tname, cname, conn_col, metadata_col, 
+                                diffs, autogen_context):
+
+    metadata_default = metadata_col.server_default
     conn_col_default = conn_col['default']
-    rendered_metadata_default = _render_server_default(metadata_default)
-    if conn_col_default != rendered_metadata_default:
+    if conn_col_default is None and metadata_default is None:
+        return False
+    rendered_metadata_default = _render_server_default(metadata_default, autogen_context)
+    isdiff = autogen_context['context'].compare_server_default(
+                        conn_col, metadata_col,
+                        rendered_metadata_default
+                    )
+    if isdiff:
+        conn_col_default = conn_col['default']
         diffs.append(
             ("modify_default", tname, cname, 
                 {
@@ -190,52 +211,33 @@ def _compare_server_default(tname, cname, conn_col, metadata_default, diffs):
             cname
         )
 
-def _string_compare(t1, t2):
-    return \
-        t1.length is not None and \
-        t1.length != t2.length
-
-def _numeric_compare(t1, t2):
-    return \
-        (
-            t1.precision is not None and \
-            t1.precision != t2.precision
-        ) or \
-        (
-            t1.scale is not None and \
-            t1.scale != t2.scale
-        )
-_type_comparators = {
-    sqltypes.String:_string_compare,
-    sqltypes.Numeric:_numeric_compare
-}
 
 ###################################################
 # produce command structure
 
-def _produce_upgrade_commands(diffs, imports):
+def _produce_upgrade_commands(diffs, autogen_context):
     buf = []
     for diff in diffs:
-        buf.append(_invoke_command("upgrade", diff, imports))
+        buf.append(_invoke_command("upgrade", diff, autogen_context))
     return "\n".join(buf)
 
-def _produce_downgrade_commands(diffs, imports):
+def _produce_downgrade_commands(diffs, autogen_context):
     buf = []
     for diff in diffs:
-        buf.append(_invoke_command("downgrade", diff, imports))
+        buf.append(_invoke_command("downgrade", diff, autogen_context))
     return "\n".join(buf)
 
-def _invoke_command(updown, args, imports):
+def _invoke_command(updown, args, autogen_context):
     if isinstance(args, tuple):
-        return _invoke_adddrop_command(updown, args, imports)
+        return _invoke_adddrop_command(updown, args, autogen_context)
     else:
-        return _invoke_modify_command(updown, args, imports)
+        return _invoke_modify_command(updown, args, autogen_context)
 
-def _invoke_adddrop_command(updown, args, imports):
+def _invoke_adddrop_command(updown, args, autogen_context):
     cmd_type = args[0]
     adddrop, cmd_type = cmd_type.split("_")
 
-    cmd_args = args[1:] + (imports,)
+    cmd_args = args[1:] + (autogen_context,)
 
     _commands = {
         "table":(_drop_table, _add_table),
@@ -253,7 +255,7 @@ def _invoke_adddrop_command(updown, args, imports):
     else:
         return cmd_callables[0](*cmd_args)
 
-def _invoke_modify_command(updown, args, imports):
+def _invoke_modify_command(updown, args, autogen_context):
     tname, cname = args[0][1:3]
     kw = {}
 
@@ -281,16 +283,16 @@ def _invoke_modify_command(updown, args, imports):
         kw.pop("existing_nullable", None)
     if "server_default" in kw:
         kw.pop("existing_server_default", None)
-    return _modify_col(tname, cname, imports, **kw)
+    return _modify_col(tname, cname, autogen_context, **kw)
 
 ###################################################
 # render python
 
-def _add_table(table, imports):
+def _add_table(table, autogen_context):
     return "create_table(%(tablename)r,\n%(args)s\n)" % {
         'tablename':table.name,
         'args':',\n'.join(
-            [_render_column(col, imports) for col in table.c] +
+            [_render_column(col, autogen_context) for col in table.c] +
             sorted([rcons for rcons in 
                 [_render_constraint(cons) for cons in 
                     table.constraints]
@@ -299,19 +301,19 @@ def _add_table(table, imports):
         ),
     }
 
-def _drop_table(table, imports):
+def _drop_table(table, autogen_context):
     return "drop_table(%r)" % table.name
 
-def _add_column(tname, column, imports):
+def _add_column(tname, column, autogen_context):
     return "add_column(%r, %s)" % (
             tname, 
-            _render_column(column, imports))
+            _render_column(column, autogen_context))
 
-def _drop_column(tname, column, imports):
+def _drop_column(tname, column, autogen_context):
     return "drop_column(%r, %r)" % (tname, column.name)
 
 def _modify_col(tname, cname, 
-                imports,
+                autogen_context,
                 server_default=False,
                 type_=None,
                 nullable=None,
@@ -322,12 +324,12 @@ def _modify_col(tname, cname,
     indent = " " * 11
     text = "alter_column(%r, %r" % (tname, cname)
     text += ", \n%sexisting_type=%s" % (indent, 
-                    _repr_type(prefix, existing_type, imports))
+                    _repr_type(prefix, existing_type, autogen_context))
     if server_default is not False:
         text += ", \n%sserver_default=%s" % (indent, 
-                        _render_server_default(server_default),)
+                        _render_server_default(server_default, autogen_context),)
     if type_ is not None:
-        text += ", \n%stype_=%s" % (indent, _repr_type(prefix, type_, imports))
+        text += ", \n%stype_=%s" % (indent, _repr_type(prefix, type_, autogen_context))
     if nullable is not None:
         text += ", \n%snullable=%r" % (
                         indent, nullable,)
@@ -338,19 +340,20 @@ def _modify_col(tname, cname,
         text += ", \n%sexisting_server_default=%s" % (
                         indent, 
                         _render_server_default(
-                            existing_server_default),
+                            existing_server_default, 
+                            autogen_context),
                     )
     text += ")"
     return text
 
 def _autogenerate_prefix():
-    return _context_opts['autogenerate_sqlalchemy_prefix']
+    return _context_opts['autogenerate_sqlalchemy_prefix'] or ''
 
-def _render_column(column, imports):
+def _render_column(column, autogen_context):
     opts = []
     if column.server_default:
         opts.append(("server_default", 
-                    _render_server_default(column.server_default)))
+                    _render_server_default(column.server_default, autogen_context)))
     if column.nullable is not None:
         opts.append(("nullable", column.nullable))
 
@@ -358,30 +361,32 @@ def _render_column(column, imports):
     return "%(prefix)sColumn(%(name)r, %(type)s, %(kw)s)" % {
         'prefix':_autogenerate_prefix(),
         'name':column.name,
-        'type':_repr_type(_autogenerate_prefix(), column.type, imports),
+        'type':_repr_type(_autogenerate_prefix(), column.type, autogen_context),
         'kw':", ".join(["%s=%s" % (kwname, val) for kwname, val in opts])
     }
 
-def _render_server_default(default):
+def _render_server_default(default, autogen_context):
     if isinstance(default, schema.DefaultClause):
         if isinstance(default.arg, basestring):
             default = default.arg
         else:
-            default = str(default.arg)
+            default = str(default.arg.compile(dialect=autogen_context['dialect']))
     if isinstance(default, basestring):
         # TODO: this is just a hack to get 
         # tests to pass until we figure out
         # WTF sqlite is doing
-        default = default.replace("'", "")
+        default = re.sub(r"^'|'$", "", default)
         return "'%s'" % default
     else:
         return None
 
-def _repr_type(prefix, type_, imports):
+def _repr_type(prefix, type_, autogen_context):
     mod = type(type_).__module__
+    imports = autogen_context.get('imports', None)
     if mod.startswith("sqlalchemy.dialects"):
         dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1)
-        imports.add("from sqlalchemy.dialects import %s" % dname)
+        if imports is not None:
+            imports.add("from sqlalchemy.dialects import %s" % dname)
         return "%s.%r" % (dname, type_)
     else:
         return "%s%r" % (prefix, type_)
index adba2300b7b87320fe8799e5d2f388d86ea19d8e..0ac9cb9ab1266cb1c5d8db890a246e705ad37fd7 100644 (file)
@@ -16,20 +16,22 @@ _version = Table('alembic_version', _meta,
 
 class Context(object):
     """Maintains state throughout the migration running process.
-    
+
     Mediates the relationship between an ``env.py`` environment script, 
     a :class:`.ScriptDirectory` instance, and a :class:`.DefaultImpl` instance.
 
     The :class:`.Context` is available directly via the :func:`.get_context` function,
     though usually it is referenced behind the scenes by the various module level functions
     within the :mod:`alembic.context` module.
-    
+
     """
     def __init__(self, dialect, script, connection, fn, 
                         as_sql=False, 
                         output_buffer=None,
                         transactional_ddl=None,
-                        starting_rev=None):
+                        starting_rev=None,
+                        compare_type=False,
+                        compare_server_default=False):
         self.dialect = dialect
         self.script = script
         if as_sql:
@@ -41,6 +43,9 @@ class Context(object):
         self.as_sql = as_sql
         self.output_buffer = output_buffer if output_buffer else sys.stdout
 
+        self._user_compare_type = compare_type
+        self._user_compare_server_default = compare_server_default
+
         self._start_from_rev = starting_rev
         self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
                             dialect, connection, self.as_sql,
@@ -116,7 +121,7 @@ class Context(object):
     @property
     def bind(self):
         """Return the current "bind".
-        
+
         In online mode, this is an instance of
         :class:`sqlalchemy.engine.base.Connection`, and is suitable
         for ad-hoc execution of any kind of usage described 
@@ -124,15 +129,58 @@ class Context(object):
         for usage with the :meth:`sqlalchemy.schema.Table.create`
         and :meth:`sqlalchemy.schema.MetaData.create_all` methods
         of :class:`.Table`, :class:`.MetaData`.
-        
+
         Note that when "standard output" mode is enabled, 
         this bind will be a "mock" connection handler that cannot
         return results and is only appropriate for a very limited
         subset of commands.
-        
+
         """
         return self.connection
 
+    def compare_type(self, inspector_column, metadata_column):
+        if self._user_compare_type is False:
+            return False
+
+        if callable(self._user_compare_type):
+            user_value = self._user_compare_type(
+                self,
+                inspector_column,
+                metadata_column,
+                inspector_column['type'],
+                metadata_column.type
+            )
+            if user_value is not None:
+                return user_value
+
+        return self.impl.compare_type(
+                                    inspector_column, 
+                                    metadata_column)
+
+    def compare_server_default(self, inspector_column, 
+                            metadata_column, 
+                            rendered_metadata_default):
+
+        if self._user_compare_server_default is False:
+            return False
+
+        if callable(self._user_compare_server_default):
+            user_value = self._user_compare_server_default(
+                    self,
+                    inspector_column,
+                    metadata_column,
+                    inspector_column['default'],
+                    metadata_column.server_default,
+                    rendered_metadata_default
+            )
+            if user_value is not None:
+                return user_value
+
+        return self.impl.compare_server_default(
+                                inspector_column, 
+                                metadata_column, 
+                                rendered_metadata_default)
+
 config = None
 """The current :class:`.Config` object.
 
@@ -151,9 +199,9 @@ _script = None
 def _opts(cfg, script, **kw):
     """Set up options that will be used by the :func:`.configure`
     function.
-    
+
     This basically sets some global variables.
-    
+
     """
     global config, _script
     _context_opts.update(kw)
@@ -168,27 +216,27 @@ def _clear():
 def is_offline_mode():
     """Return True if the current migrations environment 
     is running in "offline mode".
-    
+
     This is ``True`` or ``False`` depending 
     on the the ``--sql`` flag passed.
 
     This function does not require that the :class:`.Context` 
     has been configured.
-    
+
     """
     return _context_opts.get('as_sql', False)
 
 def is_transactional_ddl():
     """Return True if the context is configured to expect a
     transactional DDL capable backend.
-    
+
     This defaults to the type of database in use, and 
     can be overridden by the ``transactional_ddl`` argument
     to :func:`.configure`
-    
+
     This function requires that a :class:`.Context` has first been 
     made available via :func:`.configure`.
-    
+
     """
     return get_context().impl.transactional_ddl
 
@@ -200,14 +248,14 @@ def get_head_revision():
 
     This function does not require that the :class:`.Context` 
     has been configured.
-    
+
     """
     return _script._as_rev_number("head")
 
 def get_starting_revision_argument():
     """Return the 'starting revision' argument,
     if the revision was passed using ``start:end``.
-    
+
     This is only meaningful in "offline" mode.
     Returns ``None`` if no value is available
     or was configured.
@@ -228,7 +276,7 @@ def get_revision_argument():
 
     This is typically the argument passed to the 
     ``upgrade`` or ``downgrade`` command.
-    
+
     If it was specified as ``head``, the actual 
     version number is returned; if specified
     as ``base``, ``None`` is returned.
@@ -249,7 +297,7 @@ def get_tag_argument():
 
     This function does not require that the :class:`.Context` 
     has been configured.
-    
+
     """
     return _context_opts.get('tag', None)
 
@@ -262,27 +310,29 @@ def configure(
         starting_rev=None,
         tag=None,
         autogenerate_metadata=None,
+        compare_type=False,
+        compare_server_default=False,
         upgrade_token="upgrades",
         downgrade_token="downgrades",
         autogenerate_sqlalchemy_prefix="sa.",
     ):
     """Configure the migration environment.
-    
+
     The important thing needed here is first a way to figure out
     what kind of "dialect" is in use.   The second is to pass
     an actual database connection, if one is required.
-    
+
     If the :func:`.requires_connection` function returns False,
     then no connection is needed here.  Otherwise, the
     ``connection`` parameter should be present as an 
     instance of :class:`sqlalchemy.engine.base.Connection`.
-    
+
     This function is typically called from the ``env.py``
     script within a migration environment.  It can be called
     multiple times for an invocation.  The most recent :class:`~sqlalchemy.engine.base.Connection`
     for which it was called is the one that will be operated upon
     by the next call to :func:`.run_migrations`.
-    
+
     :param connection: a :class:`sqlalchemy.engine.base.Connection`.  The type of dialect
      to be used will be derived from this.
     :param url: a string database url, or a :class:`sqlalchemy.engine.url.URL` object.
@@ -306,6 +356,46 @@ def configure(
      "alembic revision" command.  The tables present will be compared against
      what is locally available on the target :class:`~sqlalchemy.engine.base.Connection`
      to produce candidate upgrade/downgrade operations.
+    :param compare_type: Indicates type comparison behavior during an autogenerate
+     operation.  Defaults to ``False`` which disables type comparison.  Set to 
+     ``True`` to turn on default type comparison, which has varied accuracy depending
+     on backend.
+     
+     To customize type comparison behavior, a callable may be specified which
+     can filter type comparisons during an autogenerate operation.   The format of 
+     this callable is::
+     
+        def my_compare_type(context, inspected_column, 
+                    metadata_column, inspected_type, metadata_type):
+            # return True if the types are different,
+            # False if not, or None to allow the default implementation
+            # to compare these types
+            pass
+
+     A return value of ``None`` indicates to allow default type comparison to
+     proceed.
+
+    :param compare_server_default: Indicates server default comparison behavior during 
+     an autogenerate operation.  Defaults to ``False`` which disables server default 
+     comparison.  Set to  ``True`` to turn on server default comparison, which has 
+     varied accuracy depending on backend.
+    
+     To customize server default comparison behavior, a callable may be specified
+     which can filter server default comparisons during an autogenerate operation.
+     defaults during an autogenerate operation.   The format of this callable is::
+     
+        def my_compare_server_default(context, inspected_column, 
+                    metadata_column, inspected_default, metadata_default,
+                    rendered_metadata_default):
+            # return True if the defaults are different,
+            # False if not, or None to allow the default implementation
+            # to compare these defaults
+            pass
+
+     A return value of ``None`` indicates to allow default server default comparison 
+     to proceed.  Note that some backends such as Postgresql actually execute
+     the two defaults on the database side to compare for equivalence.
+
     :param upgrade_token: when running "alembic revision" with the ``--autogenerate``
      option, the text of the candidate upgrade operations will be present in this
      template variable when script.py.mako is rendered.
@@ -315,7 +405,7 @@ def configure(
     :param autogenerate_sqlalchemy_prefix: When autogenerate refers to SQLAlchemy 
      :class:`~sqlalchemy.schema.Column` or type classes, this prefix will be used
      (i.e. ``sa.Column("somename", sa.Integer)``)
-     
+
     """
 
     if connection:
@@ -351,7 +441,9 @@ def configure(
                         as_sql=opts.get('as_sql', False), 
                         output_buffer=opts.get("output_buffer"),
                         transactional_ddl=opts.get("transactional_ddl"),
-                        starting_rev=opts.get("starting_rev")
+                        starting_rev=opts.get("starting_rev"),
+                        compare_type=compare_type,
+                        compare_server_default=compare_server_default,
                     )
 
 def configure_connection(connection):
@@ -374,7 +466,7 @@ def run_migrations(**kw):
 
     This function requires that a :class:`.Context` has first been 
     made available via :func:`.configure`.
-    
+
     """
     get_context().run_migrations(**kw)
 
@@ -388,19 +480,19 @@ def execute(sql):
 
     This function requires that a :class:`.Context` has first been 
     made available via :func:`.configure`.
-    
+
     """
     get_context().execute(sql)
 
 def get_context():
     """Return the current :class:`.Context` object.
-    
+
     If :func:`.configure` has not been called yet, raises
     an exception.
-    
+
     Generally, env.py scripts should access the module-level functions
     in :mod:`alebmic.context` to get at this object's functionality.
-    
+
     """
     if _context is None:
         raise Exception("No context has been configured yet.")
@@ -408,14 +500,14 @@ def get_context():
 
 def get_bind():
     """Return the current 'bind'.
-    
+
     In "online" mode, this is the 
     :class:`sqlalchemy.engine.Connection` currently being used
     to emit SQL to the database.
 
     This function requires that a :class:`.Context` has first been 
     made available via :func:`.configure`.
-    
+
     """
     return get_context().bind
 
index 83bd3ac3ddd662549d106acc06b8efba711e0f9e..2c6a666b15f3f6d448f83011ae1f3fec5e4f30a8 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy.sql.expression import _BindParamClause
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy import schema
 from alembic.ddl import base
+from sqlalchemy import types as sqltypes
 
 class ImplMeta(type):
     def __init__(cls, classname, bases, dict_):
@@ -145,6 +146,32 @@ class DefaultImpl(object):
         else:
             self._exec(table.insert(), *rows)
 
+    def compare_type(self, inspector_column, metadata_column):
+
+        conn_type = inspector_column['type']
+        metadata_type = metadata_column.type
+
+        metadata_impl = metadata_type.dialect_impl(self.dialect)
+
+        # work around SQLAlchemy bug "stale value for type affinity"
+        # fixed in 0.7.4
+        metadata_impl.__dict__.pop('_type_affinity', None)
+
+        if conn_type._compare_type_affinity(
+                            metadata_impl
+                        ):
+            comparator = _type_comparators.get(conn_type._type_affinity, None)
+
+            return comparator and comparator(metadata_type, conn_type)
+        else:
+            return True
+
+    def compare_server_default(self, inspector_column, 
+                            metadata_column, 
+                            rendered_metadata_default):
+        conn_col_default = inspector_column['default']
+        return conn_col_default != rendered_metadata_default
+
 
 class _literal_bindparam(_BindParamClause):
     pass
@@ -153,3 +180,27 @@ class _literal_bindparam(_BindParamClause):
 def _render_literal_bindparam(element, compiler, **kw):
     return compiler.render_literal_bindparam(element, **kw)
 
+
+def _string_compare(t1, t2):
+    return \
+        t1.length is not None and \
+        t1.length != t2.length
+
+def _numeric_compare(t1, t2):
+    return \
+        (
+            t1.precision is not None and \
+            t1.precision != t2.precision
+        ) or \
+        (
+            t1.scale is not None and \
+            t1.scale != t2.scale
+        )
+_type_comparators = {
+    sqltypes.String:_string_compare,
+    sqltypes.Numeric:_numeric_compare
+}
+
+
+
+
index f6268424ec8c077952672ac981fa7e0d5e9b088e..27bbe90e117974a1080fb3893a920f6ebc2807bb 100644 (file)
@@ -1,5 +1,28 @@
 from alembic.ddl.impl import DefaultImpl
+from sqlalchemy import types as sqltypes
+import re
 
 class PostgresqlImpl(DefaultImpl):
     __dialect__ = 'postgresql'
     transactional_ddl = True
+
+    def compare_server_default(self, inspector_column, 
+                            metadata_column, 
+                            rendered_metadata_default):
+
+        # don't do defaults for SERIAL columns
+        if metadata_column.primary_key and \
+            metadata_column is metadata_column.table._autoincrement_column:
+            return False
+
+        conn_col_default = inspector_column['default']
+
+        if metadata_column.type._type_affinity is not sqltypes.String:
+            rendered_metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default)
+
+        return not self.connection.execute(
+            "SELECT %s = %s" % (
+                conn_col_default,
+                rendered_metadata_default
+            )
+        )
index 07f8b0e4f96ccf1faece7379db760076ac3f50d8..ce24423799f83fd4653778ce1d49e379010e4869 100644 (file)
@@ -457,12 +457,30 @@ is already present::
 The migration hasn't actually run yet, of course.  We do that via the usual ``upgrade``
 command.   We should also go into our migration file and alter it as needed, including 
 adjustments to the directives as well as the addition of other directives which these may
-be dependent on - specifically data changes in between creates/alters/drops.   The autogenerate
-feature can currently detect:
+be dependent on - specifically data changes in between creates/alters/drops.   
+
+Autogenerate will by default detect:
 
 * Table additions, removals.
-* Column additions, removals
-* Change of column type, nullable status, default value
+* Column additions, removals.
+* Change of nullable status on columns.
+
+Autogenerate can *optionally* detect:
+
+* Change of column type.  This will occur if you set ``compare_type=True``
+  on :func:`.context.configure`.  The feature works well in most cases,
+  but is off by default so that it can be tested on the target schema
+  first.  It can also be customized by passing a callable here; see the
+  function's documentation for details.
+* Change of server default.  This will occur if you set 
+  ``compare_server_default=True`` on :func:`.context.configure`.  
+  This feature works well for simple cases but cannot always produce 
+  accurate results.  The Postgresql backend will actually invoke 
+  the "detected" and "metadata" values against the database to 
+  determine equivalence.  The feature is off by default so that
+  it can be tested on the target schema first.  Like type comparison,
+  it can also be customized by passing a callable; see the
+  function's documentation for details.
 
 Autogenerate can *not* detect:
 
index bc8cfea0958b8a6111e0694f92dd8a1ac40267ab..694bec6443ac49241e352f36ec4897ebc1d20d6e 100644 (file)
@@ -67,6 +67,13 @@ class AutogenerateDiffTest(TestCase):
         cls.bind = sqlite_db()
         cls.m1 = _model_one()
         cls.m1.create_all(cls.bind)
+        cls.m2 = _model_two()
+        context.configure(
+            connection = cls.bind.connect(),
+            compare_type = True,
+            compare_server_default = True,
+            autogenerate_metadata=cls.m2
+        )
 
     @classmethod
     def teardown_class(cls):
@@ -75,11 +82,15 @@ class AutogenerateDiffTest(TestCase):
     def test_diffs(self):
         """test generation of diff rules"""
 
-        metadata = _model_two()
-        connection = self.bind.connect()
+        metadata = self.m2
+        connection = context.get_bind()
         diffs = []
-        autogenerate._produce_net_changes(connection, metadata, diffs)
-        print "\n".join(repr(d) for d in diffs)
+        autogenerate._produce_net_changes(connection, metadata, diffs, {
+                                            'imports':set(),
+                                            'connection':connection,
+                                            'dialect':connection.dialect,
+                                            'context':context.get_context()
+                                            })
 
         eq_(
             diffs[0],
@@ -126,13 +137,9 @@ class AutogenerateDiffTest(TestCase):
         # TODO: this test isn't going
         # to be so spectacular on Py3K...
 
-        metadata = _model_two()
-        connection = self.bind.connect()
+        metadata = self.m2
         template_args = {}
-        context.configure(
-            connection=connection, 
-            autogenerate_metadata=metadata)
-        autogenerate.produce_migration_diffs(template_args, set())
+        autogenerate.produce_migration_diffs(template_args, {})
         eq_(template_args['upgrades'],
 """### commands auto generated by Alembic - please adjust! ###
     create_table('item',
@@ -200,7 +207,7 @@ class AutogenRenderTest(TestCase):
             Column("amount", Numeric(5, 2)),
         )
         eq_ignore_whitespace(
-            autogenerate._add_table(t, set()),
+            autogenerate._add_table(t, {}),
             "create_table('test',"
             "sa.Column('id', sa.Integer(), nullable=False),"
             "sa.Column('address_id', sa.Integer(), nullable=True),"
@@ -215,14 +222,14 @@ class AutogenRenderTest(TestCase):
 
     def test_render_drop_table(self):
         eq_(
-            autogenerate._drop_table(Table("sometable", MetaData()), set()),
+            autogenerate._drop_table(Table("sometable", MetaData()), {}),
             "drop_table('sometable')"
         )
 
     def test_render_add_column(self):
         eq_(
             autogenerate._add_column(
-                    "foo", Column("x", Integer, server_default="5"), set()),
+                    "foo", Column("x", Integer, server_default="5"), {}),
             "add_column('foo', sa.Column('x', sa.Integer(), "
                 "server_default='5', nullable=True))"
         )
@@ -230,7 +237,7 @@ class AutogenRenderTest(TestCase):
     def test_render_drop_column(self):
         eq_(
             autogenerate._drop_column(
-                    "foo", Column("x", Integer, server_default="5"), set()),
+                    "foo", Column("x", Integer, server_default="5"), {}),
 
             "drop_column('foo', 'x')"
         )
@@ -239,7 +246,7 @@ class AutogenRenderTest(TestCase):
         eq_ignore_whitespace(
             autogenerate._modify_col(
                         "sometable", "somecolumn", 
-                        set(),
+                        {},
                         type_=CHAR(10), existing_type=CHAR(20)),
             "alter_column('sometable', 'somecolumn', "
                 "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))"
@@ -249,7 +256,7 @@ class AutogenRenderTest(TestCase):
         eq_ignore_whitespace(
             autogenerate._modify_col(
                         "sometable", "somecolumn", 
-                        set(),
+                        {},
                         existing_type=Integer(),
                         nullable=True),
             "alter_column('sometable', 'somecolumn', "
@@ -260,7 +267,7 @@ class AutogenRenderTest(TestCase):
         eq_ignore_whitespace(
             autogenerate._modify_col(
                         "sometable", "somecolumn", 
-                        set(),
+                        {},
                         existing_type=Integer(),
                         existing_server_default="5",
                         nullable=True),