From 3c93a7ea8ac63fbad94f35950c497a1353616e5f Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 24 Apr 2019 13:46:03 -0500 Subject: [PATCH] Raise for non-string revision identifier Added an assertion in :meth:`.RevisionMap.get_revisions` and other methods which ensures revision numbers are passed as strings or collections of strings. Driver issues particularly on MySQL may inadvertently be passing bytes here which leads to failures later on. Change-Id: Id80c958f0a082fed26cac2cf838cb7507b8d814c Fixes: #551 --- alembic/script/revision.py | 15 +++++++++++++++ alembic/testing/requirements.py | 8 ++++++++ docs/build/unreleased/551.rst | 8 ++++++++ tests/test_revision.py | 34 +++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+) create mode 100644 docs/build/unreleased/551.rst diff --git a/alembic/script/revision.py b/alembic/script/revision.py index af08688d..43c757ef 100644 --- a/alembic/script/revision.py +++ b/alembic/script/revision.py @@ -314,6 +314,7 @@ class RevisionMap(object): full revision. """ + if isinstance(id_, (list, tuple, set, frozenset)): return sum([self.get_revisions(id_elem) for id_elem in id_], ()) else: @@ -469,6 +470,20 @@ class RevisionMap(object): def _resolve_revision_number(self, id_): if isinstance(id_, compat.string_types) and "@" in id_: branch_label, id_ = id_.split("@", 1) + + elif id_ is not None and ( + ( + isinstance(id_, tuple) + and id_ + and not isinstance(id_[0], compat.string_types) + ) + or not isinstance(id_, compat.string_types + (tuple, )) + ): + raise RevisionError( + "revision identifier %r is not a string; ensure database " + "driver settings are correct" % (id_,) + ) + else: branch_label = None diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py index 24f86672..4e88a9b0 100644 --- a/alembic/testing/requirements.py +++ b/alembic/testing/requirements.py @@ -1,3 +1,5 @@ +import sys + from alembic import util from alembic.util import sqla_compat from . import exclusions @@ -153,6 +155,12 @@ class SuiteRequirements(Requirements): "SQLAlchemy 1.2.16, 1.3.0b2 or greater required", ) + @property + def python3(self): + return exclusions.skip_if( + lambda: sys.version_info < (3,), "Python version 3.xx is required." + ) + @property def pep3147(self): diff --git a/docs/build/unreleased/551.rst b/docs/build/unreleased/551.rst new file mode 100644 index 00000000..043a1623 --- /dev/null +++ b/docs/build/unreleased/551.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, operations, mysql + :tickets: 551 + + Added an assertion in :meth:`.RevisionMap.get_revisions` and other methods + which ensures revision numbers are passed as strings or collections of + strings. Driver issues particularly on MySQL may inadvertently be passing + bytes here which leads to failures later on. \ No newline at end of file diff --git a/tests/test_revision.py b/tests/test_revision.py index 20eb309a..41d8b42a 100644 --- a/tests/test_revision.py +++ b/tests/test_revision.py @@ -3,12 +3,46 @@ from alembic.script.revision import Revision from alembic.script.revision import RevisionError from alembic.script.revision import RevisionMap from alembic.testing import assert_raises_message +from alembic.testing import config from alembic.testing import eq_ from alembic.testing.fixtures import TestBase from . import _large_map class APITest(TestBase): + @config.requirements.python3 + def test_invalid_datatype(self): + map_ = RevisionMap( + lambda: [ + Revision("a", ()), + Revision("b", ("a",)), + Revision("c", ("b",)), + ] + ) + assert_raises_message( + RevisionError, + "revision identifier b'12345' is not a string; " + "ensure database driver settings are correct", + map_.get_revisions, b'12345' + ) + + assert_raises_message( + RevisionError, + "revision identifier b'12345' is not a string; " + "ensure database driver settings are correct", + map_.get_revision, b'12345' + ) + + assert_raises_message( + RevisionError, + r"revision identifier \(b'12345',\) is not a string; " + "ensure database driver settings are correct", + map_.get_revision, (b'12345', ) + ) + + map_.get_revision(("a", )) + map_.get_revision("a") + def test_add_revision_one_head(self): map_ = RevisionMap( lambda: [ -- 2.47.2