From: Mike Bayer Date: Wed, 30 Nov 2011 00:21:58 +0000 (-0500) Subject: - add alter col default for PG/base X-Git-Tag: rel_0_1_0~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2f55f3a54319ab144a0d98287bf8b7755d7282a4;p=thirdparty%2Fsqlalchemy%2Falembic.git - add alter col default for PG/base - i want the ; after BEGIN/COMMIT for static generation, makes it easier to parse --- diff --git a/alembic/context.py b/alembic/context.py index c1f489af..60081ba8 100644 --- a/alembic/context.py +++ b/alembic/context.py @@ -526,6 +526,17 @@ def execute(sql): """ get_context().execute(sql) +def static_output(text): + """Emit text directly to the "offline" SQL stream. + + Typically this is for emitting comments that + start with --. The statement is not treated + as a SQL execution, no ; or batch separator + is added, etc. + + """ + get_context().impl.static_output(text) + def begin_transaction(): """Return a context manager that will enclose an operation within a "transaction", diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py index 8c10fb3b..867d92f4 100644 --- a/alembic/ddl/base.py +++ b/alembic/ddl/base.py @@ -99,7 +99,7 @@ def visit_column_type(element, compiler, **kw): return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - "TYPE %s" % compiler.dialect.type_compiler.process(element.type_) + "TYPE %s" % format_type(compiler, element.type_) ) @compiles(ColumnName) @@ -112,9 +112,14 @@ def visit_column_name(element, compiler, **kw): @compiles(ColumnDefault) def visit_column_default(element, compiler, **kw): - raise NotImplementedError( - "Default compilation not implemented " - "for column default change") + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "SET DEFAULT %s" % + format_server_default(compiler, element.default) + if element.default is not None + else "DROP DEFAULT" + ) def quote_dotted(name, quote): """quote the elements of a dotted name""" @@ -133,9 +138,12 @@ def format_column_name(compiler, name): return compiler.preparer.quote(name, None) def format_server_default(compiler, default): -# if isinstance(default, basestring): -# default = DefaultClause(default) - return compiler.get_column_default_string(Column("x", Integer, server_default=default)) + return compiler.get_column_default_string( + Column("x", Integer, server_default=default) + ) + +def format_type(compiler, type_): + return compiler.dialect.type_compiler.process(type_) def alter_table(compiler, name, schema): return "ALTER TABLE %s" % format_table_name(compiler, name, schema) diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index ef6ace71..74cc94bb 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -205,7 +205,7 @@ class DefaultImpl(object): via :func:`.context.begin_transaction`. """ - self.static_output("BEGIN") + self.static_output("BEGIN;") def emit_commit(self): """Emit the string ``COMMIT``, or the backend-specific @@ -215,7 +215,7 @@ class DefaultImpl(object): via :func:`.context.begin_transaction`. """ - self.static_output("COMMIT") + self.static_output("COMMIT;") class _literal_bindparam(_BindParamClause): pass diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index 6490bd55..fc4c0f3f 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -1,7 +1,7 @@ from alembic.ddl.impl import DefaultImpl from alembic.ddl.base import alter_table, AddColumn, ColumnName, \ format_table_name, format_column_name, ColumnNullable, alter_column,\ - format_server_default,ColumnDefault + format_server_default,ColumnDefault, format_type from alembic import util from sqlalchemy.ext.compiler import compiles @@ -29,7 +29,7 @@ class MSSQLImpl(DefaultImpl): self.static_output(self.batch_separator) def emit_begin(self): - self.static_output("BEGIN TRANSACTION") + self.static_output("BEGIN TRANSACTION;") def alter_column(self, table_name, column_name, nullable=None, @@ -147,7 +147,7 @@ def visit_column_nullable(element, compiler, **kw): return "%s %s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - compiler.dialect.type_compiler.process(element.existing_type), + format_type(compiler, element.existing_type), "NULL" if element.nullable else "NOT NULL" ) diff --git a/tests/test_mssql.py b/tests/test_mssql.py index 883590fd..bdf8272d 100644 --- a/tests/test_mssql.py +++ b/tests/test_mssql.py @@ -26,8 +26,8 @@ class FullEnvironmentTests(TestCase): def test_begin_comit(self): with capture_context_buffer(transactional_ddl=True) as buf: command.upgrade(self.cfg, self.a, sql=True) - assert "BEGIN TRANSACTION" in buf.getvalue() - assert "COMMIT" in buf.getvalue() + assert "BEGIN TRANSACTION;" in buf.getvalue() + assert "COMMIT;" in buf.getvalue() def test_batch_separator_default(self): with capture_context_buffer() as buf: diff --git a/tests/test_op.py b/tests/test_op.py index d69ec443..bc20d80e 100644 --- a/tests/test_op.py +++ b/tests/test_op.py @@ -98,6 +98,28 @@ def test_alter_column_type(): 'ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(50)' ) +def test_alter_column_set_default(): + context = op_fixture() + op.alter_column("t", "c", server_default="q") + context.assert_( + "ALTER TABLE t ALTER COLUMN c SET DEFAULT 'q'" + ) + +def test_alter_column_set_compiled_default(): + context = op_fixture() + op.alter_column("t", "c", server_default=func.utc_thing(func.current_timestamp())) + context.assert_( + "ALTER TABLE t ALTER COLUMN c SET DEFAULT utc_thing(CURRENT_TIMESTAMP)" + ) + +def test_alter_column_drop_default(): + context = op_fixture() + op.alter_column("t", "c", server_default=None) + context.assert_( + 'ALTER TABLE t ALTER COLUMN c DROP DEFAULT' + ) + + def test_alter_column_schema_type_unnamed(): context = op_fixture('mssql') op.alter_column("t", "c", type_=Boolean()) diff --git a/tests/test_sql_script.py b/tests/test_sql_script.py index a615bdf1..e127cb73 100644 --- a/tests/test_sql_script.py +++ b/tests/test_sql_script.py @@ -17,13 +17,13 @@ def teardown(): def test_begin_comit(): with capture_context_buffer(transactional_ddl=True) as buf: command.upgrade(cfg, a, sql=True) - assert "BEGIN" in buf.getvalue() - assert "COMMIT" in buf.getvalue() + assert "BEGIN;" in buf.getvalue() + assert "COMMIT;" in buf.getvalue() with capture_context_buffer(transactional_ddl=False) as buf: command.upgrade(cfg, a, sql=True) - assert "BEGIN" not in buf.getvalue() - assert "COMMIT" not in buf.getvalue() + assert "BEGIN;" not in buf.getvalue() + assert "COMMIT;" not in buf.getvalue() def test_version_from_none_insert(): with capture_context_buffer() as buf: