From fb47eda8ea08bc2ece4a2344eb87db115cdcb0d0 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 28 Nov 2011 22:29:43 -0500 Subject: [PATCH] - add version check for at least 06, tests for 07 in selected areas - add "requires 07" decorators to test suite - add tests for PG ENUM in offline mode. works in conjunction with the latest 0.7.4 tip of SQLAlchemy, fixes #9. Docs will be needed to illustrate how ENUM should be used. - add support for table before_create and after_create events within op.create_table(). Currently this will do the ENUM thing for PG but will also invoke any other kinds of events that might get configured on the table. --- alembic/command.py | 1 + alembic/context.py | 2 +- alembic/ddl/impl.py | 17 ++++++- alembic/op.py | 11 ++++- alembic/util.py | 17 ++++++- tests/__init__.py | 12 ++++- tests/test_autogenerate.py | 5 +- tests/test_postgresql.py | 98 ++++++++++++++++++++++++++++++++++---- tests/test_sql_script.py | 4 +- 9 files changed, 150 insertions(+), 17 deletions(-) diff --git a/alembic/command.py b/alembic/command.py index a3598ac7..f7075b18 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -66,6 +66,7 @@ def revision(config, message=None, autogenerate=False): template_args = {} imports = set() if autogenerate: + util.requires_07("autogenerate") def retrieve_migrations(rev): if script._get_rev(rev) is not script._get_rev("head"): raise util.CommandError("Target database is not up to date.") diff --git a/alembic/context.py b/alembic/context.py index 51a060e9..d9500e7b 100644 --- a/alembic/context.py +++ b/alembic/context.py @@ -48,7 +48,7 @@ class Context(object): self._start_from_rev = starting_rev self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( - dialect, connection, self.as_sql, + dialect, self.connection, self.as_sql, transactional_ddl, self.output_buffer ) diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 2c6a666b..4159d526 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -3,6 +3,7 @@ from sqlalchemy.sql.expression import _BindParamClause from sqlalchemy.ext.compiler import compiles from sqlalchemy import schema from alembic.ddl import base +from alembic import util from sqlalchemy import types as sqltypes class ImplMeta(type): @@ -31,11 +32,13 @@ class DefaultImpl(object): transactional_ddl = False - def __init__(self, dialect, connection, as_sql, transactional_ddl, output_buffer): + def __init__(self, dialect, connection, as_sql, + transactional_ddl, output_buffer): self.dialect = dialect self.connection = connection self.as_sql = as_sql self.output_buffer = output_buffer + self.memo = {} if transactional_ddl is not None: self.transactional_ddl = transactional_ddl @@ -46,6 +49,10 @@ class DefaultImpl(object): def static_output(self, text): self.output_buffer.write(text + "\n\n") + @property + def bind(self): + return self.connection + def _exec(self, construct, *args, **kw): if isinstance(construct, basestring): construct = text(construct) @@ -123,7 +130,15 @@ class DefaultImpl(object): new_table_name, schema=schema)) def create_table(self, table): + if util.sqla_07: + table.dispatch.before_create(table, self.connection, + checkfirst=False, + _ddl_runner=self) self._exec(schema.CreateTable(table)) + if util.sqla_07: + table.dispatch.after_create(table, self.connection, + checkfirst=False, + _ddl_runner=self) for index in table.indexes: self._exec(schema.CreateIndex(index)) diff --git a/alembic/op.py b/alembic/op.py index 848e40ba..224003e7 100644 --- a/alembic/op.py +++ b/alembic/op.py @@ -401,8 +401,17 @@ def create_table(name, *columns, **kw): Column('description', NVARCHAR(200)) ) + :param name: Name of the table + :param \*columns: collection of :class:`~sqlalchemy.schema.Column` objects within + the table, as well as optional :class:`~sqlalchemy.schema.Constraint` objects + and :class:`~.sqlalchemy.schema.Index` objects. + :param emit_events: if ``True``, emit ``before_create`` and ``after_create`` + events when the table is being created. In particular, the Postgresql ENUM + type will emit a CREATE TYPE within these events. + :param \**kw: Other keyword arguments are passed to the underlying + :class:`.Table` object created for the command. + """ - get_impl().create_table( _table(name, *columns, **kw) ) diff --git a/alembic/util.py b/alembic/util.py index 82572fd1..aa5d0e25 100644 --- a/alembic/util.py +++ b/alembic/util.py @@ -10,10 +10,25 @@ import time import random import uuid - class CommandError(Exception): pass +from sqlalchemy import __version__ +_vers = tuple([int(x) for x in __version__.split(".")]) +sqla_06 = _vers > (0, 6) +sqla_07 = _vers > (0, 7) +if not sqla_06: + raise CommandError( + "SQLAlchemy 0.6 or greater is required. " + "Version 0.7 or above required for full featureset.") + +def requires_07(feature): + if not sqla_07: + raise CommandError( + "The %s feature requires " + "SQLAlchemy 0.7 or greater." + % feature + ) try: width = int(os.environ['COLUMNS']) except (KeyError, ValueError): diff --git a/tests/__init__.py b/tests/__init__.py index be4c48bb..4f81b580 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -13,6 +13,7 @@ from alembic.ddl.impl import _impls import ConfigParser from nose import SkipTest from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.util import decorator staging_directory = os.path.join(os.path.dirname(__file__), 'scratch') files_directory = os.path.join(os.path.dirname(__file__), 'files') @@ -47,6 +48,12 @@ def db_for_dialect(name): _engs[name] = eng return eng +@decorator +def requires_07(fn, *arg, **kw): + if not util.sqla_07: + raise SkipTest("SQLAlchemy 0.7 required") + return fn(*arg, **kw) + _dialects = {} def _get_dialect(name): if name is None or name == 'default': @@ -117,7 +124,10 @@ def op_fixture(dialect='default', as_sql=False): self.assertion = [] self.dialect = dialect self.as_sql = as_sql - + # TODO: this might need to + # be more like a real connection + # as tests get more involved + self.connection = None def _exec(self, construct, *args, **kw): if isinstance(construct, basestring): construct = text(construct) diff --git a/tests/test_autogenerate.py b/tests/test_autogenerate.py index 5260b04f..7e433665 100644 --- a/tests/test_autogenerate.py +++ b/tests/test_autogenerate.py @@ -3,7 +3,8 @@ from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \ from sqlalchemy.types import NULLTYPE from alembic import autogenerate, context from unittest import TestCase -from tests import staging_env, sqlite_db, clear_staging_env, eq_, eq_ignore_whitespace +from tests import staging_env, sqlite_db, clear_staging_env, eq_, \ + eq_ignore_whitespace, requires_07 def _model_one(): m = MetaData() @@ -63,6 +64,7 @@ def _model_two(): class AutogenerateDiffTest(TestCase): @classmethod + @requires_07 def setup_class(cls): staging_env() cls.bind = sqlite_db() @@ -220,6 +222,7 @@ class AutogenRenderTest(TestCase): """test individual directives""" @classmethod + @requires_07 def setup_class(cls): context._context_opts['sqlalchemy_module_prefix'] = 'sa.' diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 0f88f57e..325cbb61 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1,9 +1,95 @@ -from tests import op_fixture, db_for_dialect, eq_, staging_env, clear_staging_env +from tests import op_fixture, db_for_dialect, eq_, staging_env, \ + clear_staging_env, no_sql_testing_config,\ + capture_context_buffer, requires_07 from unittest import TestCase from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, String from sqlalchemy.engine.reflection import Inspector -from alembic import context +from alembic import context, command, util +from alembic.script import ScriptDirectory + +class PGOfflineEnumTest(TestCase): + @requires_07 + def setUp(self): + env = staging_env() + self.cfg = cfg = no_sql_testing_config() + + self.rid = rid = util.rev_id() + + self.script = script = ScriptDirectory.from_config(cfg) + script.generate_rev(rid, None, refresh=True) + + def _inline_enum_script(self): + self.script.write(self.rid, """ +down_revision = None + +from alembic.op import * +from sqlalchemy.dialects.postgresql import ENUM +from sqlalchemy import Column + +def upgrade(): + create_table("sometable", + Column("data", ENUM("one", "two", "three", name="pgenum")) + ) + +def downgrade(): + drop_table("sometable") +""") + + def _distinct_enum_script(self): + self.script.write(self.rid, """ +down_revision = None + +from alembic.op import * +from sqlalchemy.dialects.postgresql import ENUM +from sqlalchemy import Column + +def upgrade(): + enum = ENUM("one", "two", "three", name="pgenum", create_type=False) + enum.create(get_bind(), checkfirst=False) + create_table("sometable", + Column("data", enum) + ) + +def downgrade(): + drop_table("sometable") + ENUM(name="pgenum").drop(get_bind(), checkfirst=False) + +""") + + def tearDown(self): + clear_staging_env() + + def test_offline_inline_enum_create(self): + self._inline_enum_script() + with capture_context_buffer() as buf: + command.upgrade(self.cfg, self.rid, sql=True) + assert "CREATE TYPE pgenum AS ENUM ('one','two','three')" in buf.getvalue() + assert "CREATE TABLE sometable (\n data pgenum\n)" in buf.getvalue() + + def test_offline_inline_enum_drop(self): + self._inline_enum_script() + with capture_context_buffer() as buf: + command.downgrade(self.cfg, "%s:base" % self.rid, sql=True) + assert "DROP TABLE sometable" in buf.getvalue() + # no drop since we didn't emit events + assert "DROP TYPE pgenum" not in buf.getvalue() + + def test_offline_distinct_enum_create(self): + self._distinct_enum_script() + with capture_context_buffer() as buf: + command.upgrade(self.cfg, self.rid, sql=True) + assert "CREATE TYPE pgenum AS ENUM ('one','two','three')" in buf.getvalue() + assert "CREATE TABLE sometable (\n data pgenum\n)" in buf.getvalue() + + def test_offline_distinct_enum_drop(self): + self._distinct_enum_script() + with capture_context_buffer() as buf: + command.downgrade(self.cfg, "%s:base" % self.rid, sql=True) + assert "DROP TABLE sometable" in buf.getvalue() + assert "DROP TYPE pgenum" in buf.getvalue() + + class PostgresqlDefaultCompareTest(TestCase): @classmethod @@ -49,14 +135,6 @@ class PostgresqlDefaultCompareTest(TestCase): assert self._compare_default( t, t2, t2.c.somecol, alternate ) is expected -# t.create(self.bind) -# insp = Inspector.from_engine(self.bind) -# cols = insp.get_columns("test") -# ctx = context.get_context() -# assert ctx.impl.compare_server_default( -# cols[0], -# t2.c.somecol, -# alternate) is expected def _compare_default( self, diff --git a/tests/test_sql_script.py b/tests/test_sql_script.py index 1c3df484..1df94cdc 100644 --- a/tests/test_sql_script.py +++ b/tests/test_sql_script.py @@ -1,4 +1,6 @@ -from tests import clear_staging_env, staging_env, no_sql_testing_config, sqlite_db, eq_, ne_, capture_context_buffer, three_rev_fixture +from tests import clear_staging_env, staging_env, \ + no_sql_testing_config, sqlite_db, eq_, ne_, capture_context_buffer, \ + three_rev_fixture from alembic import command, util def setup(): -- 2.47.2