]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
implement server default, nullability for SQL server
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 30 Nov 2011 00:01:31 +0000 (19:01 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 30 Nov 2011 00:01:31 +0000 (19:01 -0500)
alembic/ddl/base.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
tests/test_mssql.py

index a3e6eff1a7dd43ba8b383488f7c9d0d1366fcce2..8c10fb3bf137d65f5cbfd9d588c0a45633f19255 100644 (file)
@@ -1,6 +1,8 @@
 import functools
 from sqlalchemy.ext.compiler import compiles
-from sqlalchemy.schema import DDLElement
+from sqlalchemy.schema import DDLElement, Column
+from sqlalchemy import Integer
+
 from sqlalchemy import types as sqltypes
 class AlterTable(DDLElement):
     """Represent an ALTER TABLE statement.
@@ -108,6 +110,12 @@ def visit_column_name(element, compiler, **kw):
         format_column_name(compiler, element.newname)
     )
 
+@compiles(ColumnDefault)
+def visit_column_default(element, compiler, **kw):
+    raise NotImplementedError(
+            "Default compilation not implemented "
+            "for column default change")
+
 def quote_dotted(name, quote):
     """quote the elements of a dotted name"""
 
@@ -124,6 +132,11 @@ def format_table_name(compiler, name, schema):
 def format_column_name(compiler, name):
     return compiler.preparer.quote(name, None)
 
+def format_server_default(compiler, default):
+#    if isinstance(default, basestring):
+#        default = DefaultClause(default)
+    return compiler.get_column_default_string(Column("x", Integer, server_default=default))
+
 def alter_table(compiler, name, schema):
     return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
 
index 61596ea73d02bde138b825a559e52400455b73d6..ef6ace714dc363a5ee76e09689bfd8c138420057 100644 (file)
@@ -205,7 +205,7 @@ class DefaultImpl(object):
         via :func:`.context.begin_transaction`.
         
         """
-        self._exec("BEGIN")
+        self.static_output("BEGIN")
 
     def emit_commit(self):
         """Emit the string ``COMMIT``, or the backend-specific
@@ -215,7 +215,7 @@ class DefaultImpl(object):
         via :func:`.context.begin_transaction`.
         
         """
-        self._exec("COMMIT")
+        self.static_output("COMMIT")
 
 class _literal_bindparam(_BindParamClause):
     pass
index 66fddf7f074231cfec055b42e6727558e7ced98a..6490bd550e64db88d363963129caa76fe043e7da 100644 (file)
@@ -1,6 +1,7 @@
 from alembic.ddl.impl import DefaultImpl
 from alembic.ddl.base import alter_table, AddColumn, ColumnName, \
-    format_table_name, format_column_name, ColumnNullable, alter_column
+    format_table_name, format_column_name, ColumnNullable, alter_column,\
+    format_server_default,ColumnDefault
 from alembic import util
 from sqlalchemy.ext.compiler import compiles
 
@@ -28,7 +29,7 @@ class MSSQLImpl(DefaultImpl):
             self.static_output(self.batch_separator)
 
     def emit_begin(self):
-        self._exec("BEGIN TRANSACTION")
+        self.static_output("BEGIN TRANSACTION")
 
     def alter_column(self, table_name, column_name, 
                         nullable=None,
@@ -56,15 +57,32 @@ class MSSQLImpl(DefaultImpl):
         super(MSSQLImpl, self).alter_column(
                         table_name, column_name, 
                         nullable=nullable,
-                        server_default=server_default,
-                        name=name,
                         type_=type_,
                         schema=schema,
                         existing_type=existing_type,
-                        existing_server_default=existing_server_default,
                         existing_nullable=existing_nullable
         )
 
+        if server_default is not False:
+            if existing_server_default is not False or \
+                server_default is None:
+                self._exec(
+                    _exec_drop_col_constraint(self, 
+                            table_name, column_name, 
+                            'sys.default_constraints')
+                )
+            if server_default is not None:
+                super(MSSQLImpl, self).alter_column(
+                                table_name, column_name, 
+                                schema=schema,
+                                server_default=server_default)
+
+        if name is not None:
+            super(MSSQLImpl, self).alter_column(
+                                table_name, column_name, 
+                                schema=schema,
+                                name=name)
+
     def bulk_insert(self, table, rows):
         if self.as_sql:
             self._exec(
@@ -133,6 +151,15 @@ def visit_column_nullable(element, compiler, **kw):
         "NULL" if element.nullable else "NOT NULL"
     )
 
+@compiles(ColumnDefault, 'mssql')
+def visit_column_default(element, compiler, **kw):
+    # TODO: there can also be a named constraint
+    # with ADD CONSTRAINT here
+    return "%s ADD DEFAULT %s FOR %s" % (
+        alter_table(compiler, element.table_name, element.schema),
+        format_server_default(compiler, element.default),
+        format_column_name(compiler, element.column_name)
+    )
 
 @compiles(ColumnName, 'mssql')
 def visit_rename_column(element, compiler, **kw):
index 9a4c8441e6d5c300a2def488a604768c9f619c5c..883590fdf48481ca4e7caa1266bee01dba044278 100644 (file)
@@ -118,6 +118,35 @@ class OpTest(TestCase):
             op.alter_column, "t", "c", nullable=False
         )
 
+    def test_alter_add_server_default(self):
+        context = op_fixture('mssql')
+        op.alter_column("t", "c", server_default="5")
+        context.assert_(
+            "ALTER TABLE t ADD DEFAULT '5' FOR c"
+        )
+
+    def test_alter_replace_server_default(self):
+        context = op_fixture('mssql')
+        op.alter_column("t", "c", server_default="5", existing_server_default="6")
+        context.assert_contains("exec('alter table t drop constraint ' + @const_name_1)")
+        context.assert_contains(
+            "ALTER TABLE t ADD DEFAULT '5' FOR c"
+        )
+
+    def test_alter_remove_server_default(self):
+        context = op_fixture('mssql')
+        op.alter_column("t", "c", server_default=None)
+        context.assert_contains("exec('alter table t drop constraint ' + @const_name_1)")
+
+    def test_alter_do_everything(self):
+        context = op_fixture('mssql')
+        op.alter_column("t", "c", name="c2", nullable=True, type_=Integer, server_default="5")
+        context.assert_(
+            'ALTER TABLE t ALTER COLUMN c INTEGER NULL', 
+            "ALTER TABLE t ADD DEFAULT '5' FOR c", 
+            "EXEC sp_rename 't.c', 'c2', 'COLUMN'"
+        )
+
     # TODO: when we add schema support
     #def test_alter_column_rename_mssql_schema(self):
     #    context = op_fixture('mssql')