From: Mike Bayer Date: Wed, 30 Nov 2011 00:01:31 +0000 (-0500) Subject: implement server default, nullability for SQL server X-Git-Tag: rel_0_1_0~8 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=a0f0cd381dc5bb4b780fa239df277891b2236e34;p=thirdparty%2Fsqlalchemy%2Falembic.git implement server default, nullability for SQL server --- diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py index a3e6eff1..8c10fb3b 100644 --- a/alembic/ddl/base.py +++ b/alembic/ddl/base.py @@ -1,6 +1,8 @@ import functools from sqlalchemy.ext.compiler import compiles -from sqlalchemy.schema import DDLElement +from sqlalchemy.schema import DDLElement, Column +from sqlalchemy import Integer + from sqlalchemy import types as sqltypes class AlterTable(DDLElement): """Represent an ALTER TABLE statement. @@ -108,6 +110,12 @@ def visit_column_name(element, compiler, **kw): format_column_name(compiler, element.newname) ) +@compiles(ColumnDefault) +def visit_column_default(element, compiler, **kw): + raise NotImplementedError( + "Default compilation not implemented " + "for column default change") + def quote_dotted(name, quote): """quote the elements of a dotted name""" @@ -124,6 +132,11 @@ def format_table_name(compiler, name, schema): def format_column_name(compiler, name): return compiler.preparer.quote(name, None) +def format_server_default(compiler, default): +# if isinstance(default, basestring): +# default = DefaultClause(default) + return compiler.get_column_default_string(Column("x", Integer, server_default=default)) + def alter_table(compiler, name, schema): return "ALTER TABLE %s" % format_table_name(compiler, name, schema) diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 61596ea7..ef6ace71 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -205,7 +205,7 @@ class DefaultImpl(object): via :func:`.context.begin_transaction`. """ - self._exec("BEGIN") + self.static_output("BEGIN") def emit_commit(self): """Emit the string ``COMMIT``, or the backend-specific @@ -215,7 +215,7 @@ class DefaultImpl(object): via :func:`.context.begin_transaction`. """ - self._exec("COMMIT") + self.static_output("COMMIT") class _literal_bindparam(_BindParamClause): pass diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index 66fddf7f..6490bd55 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -1,6 +1,7 @@ from alembic.ddl.impl import DefaultImpl from alembic.ddl.base import alter_table, AddColumn, ColumnName, \ - format_table_name, format_column_name, ColumnNullable, alter_column + format_table_name, format_column_name, ColumnNullable, alter_column,\ + format_server_default,ColumnDefault from alembic import util from sqlalchemy.ext.compiler import compiles @@ -28,7 +29,7 @@ class MSSQLImpl(DefaultImpl): self.static_output(self.batch_separator) def emit_begin(self): - self._exec("BEGIN TRANSACTION") + self.static_output("BEGIN TRANSACTION") def alter_column(self, table_name, column_name, nullable=None, @@ -56,15 +57,32 @@ class MSSQLImpl(DefaultImpl): super(MSSQLImpl, self).alter_column( table_name, column_name, nullable=nullable, - server_default=server_default, - name=name, type_=type_, schema=schema, existing_type=existing_type, - existing_server_default=existing_server_default, existing_nullable=existing_nullable ) + if server_default is not False: + if existing_server_default is not False or \ + server_default is None: + self._exec( + _exec_drop_col_constraint(self, + table_name, column_name, + 'sys.default_constraints') + ) + if server_default is not None: + super(MSSQLImpl, self).alter_column( + table_name, column_name, + schema=schema, + server_default=server_default) + + if name is not None: + super(MSSQLImpl, self).alter_column( + table_name, column_name, + schema=schema, + name=name) + def bulk_insert(self, table, rows): if self.as_sql: self._exec( @@ -133,6 +151,15 @@ def visit_column_nullable(element, compiler, **kw): "NULL" if element.nullable else "NOT NULL" ) +@compiles(ColumnDefault, 'mssql') +def visit_column_default(element, compiler, **kw): + # TODO: there can also be a named constraint + # with ADD CONSTRAINT here + return "%s ADD DEFAULT %s FOR %s" % ( + alter_table(compiler, element.table_name, element.schema), + format_server_default(compiler, element.default), + format_column_name(compiler, element.column_name) + ) @compiles(ColumnName, 'mssql') def visit_rename_column(element, compiler, **kw): diff --git a/tests/test_mssql.py b/tests/test_mssql.py index 9a4c8441..883590fd 100644 --- a/tests/test_mssql.py +++ b/tests/test_mssql.py @@ -118,6 +118,35 @@ class OpTest(TestCase): op.alter_column, "t", "c", nullable=False ) + def test_alter_add_server_default(self): + context = op_fixture('mssql') + op.alter_column("t", "c", server_default="5") + context.assert_( + "ALTER TABLE t ADD DEFAULT '5' FOR c" + ) + + def test_alter_replace_server_default(self): + context = op_fixture('mssql') + op.alter_column("t", "c", server_default="5", existing_server_default="6") + context.assert_contains("exec('alter table t drop constraint ' + @const_name_1)") + context.assert_contains( + "ALTER TABLE t ADD DEFAULT '5' FOR c" + ) + + def test_alter_remove_server_default(self): + context = op_fixture('mssql') + op.alter_column("t", "c", server_default=None) + context.assert_contains("exec('alter table t drop constraint ' + @const_name_1)") + + def test_alter_do_everything(self): + context = op_fixture('mssql') + op.alter_column("t", "c", name="c2", nullable=True, type_=Integer, server_default="5") + context.assert_( + 'ALTER TABLE t ALTER COLUMN c INTEGER NULL', + "ALTER TABLE t ADD DEFAULT '5' FOR c", + "EXEC sp_rename 't.c', 'c2', 'COLUMN'" + ) + # TODO: when we add schema support #def test_alter_column_rename_mssql_schema(self): # context = op_fixture('mssql')