git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
add recursive_version_locations option for searching revision files
author ostr00000 <ostr00000@gmail.com>
Mon, 27 Feb 2023 23:18:19 +0000 (18:18 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Wed, 1 Mar 2023 00:04:59 +0000 (19:04 -0500)
Recursive traversal of revision files in a particular revision directory is
now supported, by indicating ``recursive_version_locations = true`` in
alembic.ini. Pull request courtesy ostr00000.

Fixes: #760
Closes: #1182
Pull-request: https://github.com/sqlalchemy/alembic/pull/1182
Pull-request-sha: ecb0da48b459abd3f5e95390ec7030a7e3fcbc6d

Change-Id: I711ca2dbd35fb9a2acdbfd374bcac13043b0d129

alembic/script/base.py
alembic/templates/async/alembic.ini.mako
alembic/templates/generic/alembic.ini.mako
alembic/templates/multidb/alembic.ini.mako
docs/build/tutorial.rst
docs/build/unreleased/760.rst [new file with mode: 0644]
tests/test_script_consumption.py
tests/test_script_production.py

index 3c09cef7d165dcca4c8fdeabc4f42a718f2880ea..b6858b59acb13471960428c1cbfc59147f8cb8d9 100644 (file)
@@ -80,6 +80,7 @@ class ScriptDirectory:
         output_encoding: str = "utf-8",
         timezone: Optional[str] = None,
         hook_config: Optional[Dict[str, str]] = None,
+        recursive_version_locations: bool = False,
     ) -> None:
         self.dir = dir
         self.file_template = file_template
@@ -90,6 +91,7 @@ class ScriptDirectory:
         self.revision_map = revision.RevisionMap(self._load_revisions)
         self.timezone = timezone
         self.hook_config = hook_config
+        self.recursive_version_locations = recursive_version_locations
 
         if not os.access(dir, os.F_OK):
             raise util.CommandError(
@@ -128,16 +130,19 @@ class ScriptDirectory:
 
         dupes = set()
         for vers in paths:
-            for file_ in Script._list_py_dir(self, vers):
-                path = os.path.realpath(os.path.join(vers, file_))
-                if path in dupes:
+            for file_path in Script._list_py_dir(self, vers):
+                real_path = os.path.realpath(file_path)
+                if real_path in dupes:
                     util.warn(
                         "File %s loaded twice! ignoring. Please ensure "
-                        "version_locations is unique." % path
+                        "version_locations is unique." % real_path
                     )
                     continue
-                dupes.add(path)
-                script = Script._from_filename(self, vers, file_)
+                dupes.add(real_path)
+
+                filename = os.path.basename(real_path)
+                dir_name = os.path.dirname(real_path)
+                script = Script._from_filename(self, dir_name, filename)
                 if script is None:
                     continue
                 yield script
@@ -207,6 +212,7 @@ class ScriptDirectory:
                 _split_on_space_comma_colon.split(prepend_sys_path)
             )
 
+        rvl = config.get_main_option("recursive_version_locations") == "true"
         return ScriptDirectory(
             util.coerce_resource_to_filename(script_location),
             file_template=config.get_main_option(
@@ -218,6 +224,7 @@ class ScriptDirectory:
             version_locations=version_locations,
             timezone=config.get_main_option("timezone"),
             hook_config=config.get_section("post_write_hooks", {}),
+            recursive_version_locations=rvl,
         )
 
     @contextmanager
@@ -959,26 +966,40 @@ class Script(revision.Revision):
 
     @classmethod
     def _list_py_dir(cls, scriptdir: ScriptDirectory, path: str) -> List[str]:
-        if scriptdir.sourceless:
-            # read files in version path, e.g. pyc or pyo files
-            # in the immediate path
-            paths = os.listdir(path)
-
-            names = {fname.split(".")[0] for fname in paths}
-
-            # look for __pycache__
-            if os.path.exists(os.path.join(path, "__pycache__")):
-                # add all files from __pycache__ whose filename is not
-                # already in the names we got from the version directory.
-                # add as relative paths including __pycache__ token
-                paths.extend(
-                    os.path.join("__pycache__", pyc)
-                    for pyc in os.listdir(os.path.join(path, "__pycache__"))
-                    if pyc.split(".")[0] not in names
-                )
-            return paths
-        else:
-            return os.listdir(path)
+        paths = []
+        for root, dirs, files in os.walk(path, topdown=True):
+            if root.endswith("__pycache__"):
+                # a special case - we may include these files
+                # if a `sourceless` option is specified
+                continue
+
+            for filename in sorted(files):
+                paths.append(os.path.join(root, filename))
+
+            if scriptdir.sourceless:
+                # look for __pycache__
+                py_cache_path = os.path.join(root, "__pycache__")
+                if os.path.exists(py_cache_path):
+                    # add all files from __pycache__ whose filename is not
+                    # already in the names we got from the version directory.
+                    # each entry is joined onto the directory being walked,
+                    # so it includes the __pycache__ path component
+                    names = {filename.split(".")[0] for filename in files}
+                    paths.extend(
+                        os.path.join(py_cache_path, pyc)
+                        for pyc in os.listdir(py_cache_path)
+                        if pyc.split(".")[0] not in names
+                    )
+
+            if not scriptdir.recursive_version_locations:
+                break
+
+            # the real script order is defined by revision, but it may be
+            # undefined when several files share the same `down_revision`;
+            # for a better user experience (e.g. when debugging), we visit
+            # subdirectories in a deterministic (sorted) order
+            dirs.sort()
+
+        return paths
 
     @classmethod
     def _from_filename(
index 5268e7cd7184d7ccb49fcfd31248a21b144d50ff..64c7b6b97d4949afd867f77549c30cb675ff865c 100644 (file)
@@ -49,6 +49,11 @@ prepend_sys_path = .
 # version_path_separator = space
 version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.
 
+# set to 'true' to search source files recursively
+# in each "version_locations" directory
+# new in Alembic version 1.10
+# recursive_version_locations = false
+
 # the output encoding used when revision files
 # are written from script.py.mako
 # output_encoding = utf-8
index 8aa47b19b5d15c176bda11bd296903b8e672581a..f541b179a0d542334d4baca0754bcb88c9214be0 100644 (file)
@@ -51,6 +51,11 @@ prepend_sys_path = .
 # version_path_separator = space
 version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.
 
+# set to 'true' to search source files recursively
+# in each "version_locations" directory
+# new in Alembic version 1.10
+# recursive_version_locations = false
+
 # the output encoding used when revision files
 # are written from script.py.mako
 # output_encoding = utf-8
index 5adef392fdee3cfd3327e7d4888ac4a8ded86970..4230fe1357917bbb847c6567934f5f7d3fb26a2f 100644 (file)
@@ -51,6 +51,11 @@ prepend_sys_path = .
 # version_path_separator = space
 version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.
 
+# set to 'true' to search source files recursively
+# in each "version_locations" directory
+# new in Alembic version 1.10
+# recursive_version_locations = false
+
 # the output encoding used when revision files
 # are written from script.py.mako
 # output_encoding = utf-8
index 8540be13d92349a33cd993946ab76431bef245c9..8823a91bfddebbae0090a9461b293dc0bbe344e3 100644 (file)
@@ -177,6 +177,11 @@ The file generated with the "generic" configuration looks like::
     # version_path_separator = space
     version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.
 
+    # set to 'true' to search source files recursively
+    # in each "version_locations" directory
+    # new in Alembic version 1.10
+    # recursive_version_locations = false
+
     # the output encoding used when revision files
     # are written from script.py.mako
     # output_encoding = utf-8
@@ -332,6 +337,11 @@ This file contains the following features:
   It should be defined if multiple ``version_locations`` is used.
   See :ref:`multiple_bases` for examples.
 
+* ``recursive_version_locations`` - when set to 'true', revision files
+  are searched recursively in each "version_locations" directory.
+
+  .. versionadded:: 1.10
+
 * ``output_encoding`` - the encoding to use when Alembic writes the
   ``script.py.mako`` file into a new migration file.  Defaults to ``'utf-8'``.
 
diff --git a/docs/build/unreleased/760.rst b/docs/build/unreleased/760.rst
new file mode 100644 (file)
index 0000000..5f46e10
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: feature, revisioning
+    :tickets: 760
+
+    Recursive traversal of revision files in a particular revision directory is
+    now supported, by indicating ``recursive_version_locations = true`` in
+    alembic.ini. Pull request courtesy ostr00000.
+
index fa84d7e3789f29f77b22ba74ef4527ba1d9f9959..a107b8055a1a4415e60c9e5fde4174b38bcc5c78 100644 (file)
@@ -1,7 +1,12 @@
+from __future__ import annotations
+
 from contextlib import contextmanager
 import os
 import re
+import shutil
 import textwrap
+from typing import Dict
+from typing import List
 
 import sqlalchemy as sa
 from sqlalchemy import pool
@@ -9,18 +14,24 @@ from sqlalchemy import pool
 from alembic import command
 from alembic import testing
 from alembic import util
+from alembic.config import Config
 from alembic.environment import EnvironmentContext
 from alembic.script import Script
 from alembic.script import ScriptDirectory
 from alembic.testing import assert_raises_message
+from alembic.testing import assertions
 from alembic.testing import config
 from alembic.testing import eq_
+from alembic.testing import expect_raises_message
 from alembic.testing import mock
+from alembic.testing.env import _get_staging_directory
+from alembic.testing.env import _multi_dir_testing_config
 from alembic.testing.env import _no_sql_testing_config
 from alembic.testing.env import _sqlite_file_db
 from alembic.testing.env import _sqlite_testing_config
 from alembic.testing.env import clear_staging_env
 from alembic.testing.env import env_file_fixture
+from alembic.testing.env import multi_heads_fixture
 from alembic.testing.env import staging_env
 from alembic.testing.env import three_rev_fixture
 from alembic.testing.env import write_script
@@ -900,3 +911,325 @@ class SourcelessNeedsFlagTest(TestBase):
         self.cfg.set_main_option("sourceless", "true")
         script = ScriptDirectory.from_config(self.cfg)
         eq_(script.get_heads(), [a])
+
+
+class RecursiveScriptDirectoryTest(TestBase):
+    """test recursive version directory consumption for #760"""
+
+    rev: List[str]
+    org_script_dir: ScriptDirectory
+    cfg: Config
+    _script_by_name: Dict[str, Script]
+    _name_by_revision: Dict[str, str]
+
+    def _setup_revision_files(
+        self, listing, destination=".", version_path="scripts/versions"
+    ):
+        for elem in listing:
+            if isinstance(elem, str):
+                if destination != ".":
+                    script = self._script_by_name[elem]
+                    target_file = self._get_moved_path(
+                        elem, destination, version_path
+                    )
+                    os.makedirs(os.path.dirname(target_file), exist_ok=True)
+                    shutil.move(script.path, target_file)
+            else:
+                dest, files = elem
+                if dest == "delete":
+                    for fname in files:
+                        revision_to_remove = self._script_by_name[fname]
+                        os.remove(revision_to_remove.path)
+                else:
+                    self._setup_revision_files(
+                        files, os.path.join(destination, dest), version_path
+                    )
+
+    def _get_moved_path(
+        self,
+        elem: str,
+        destination_dir: str = "",
+        version_path="scripts/versions",
+    ):
+        script = self._script_by_name[elem]
+        file_name = os.path.basename(script.path)
+        target_file = os.path.join(
+            _get_staging_directory(), version_path, destination_dir, file_name
+        )
+        target_file = os.path.realpath(target_file)
+        return target_file
+
+    def _assert_setup(self, *elements):
+        sd = ScriptDirectory.from_config(self.cfg)
+
+        _new_rev_to_script = {
+            self._name_by_revision[r.revision]: r for r in sd.walk_revisions()
+        }
+
+        for revname, directory, version_path in elements:
+            eq_(
+                _new_rev_to_script[revname].path,
+                self._get_moved_path(revname, directory, version_path),
+            )
+
+        eq_(len(_new_rev_to_script), len(elements))
+
+        revs_to_check = {
+            self._script_by_name[rev].revision for rev, _, _ in elements
+        }
+
+        # topological order check
+        for rev_id in revs_to_check:
+            new_script = sd.get_revision(rev_id)
+            assertions.is_not_(new_script, None)
+
+            old_revisions = {
+                r.revision: r
+                for r in self.org_script_dir.revision_map.iterate_revisions(
+                    rev_id,
+                    "base",
+                    inclusive=True,
+                    assert_relative_length=False,
+                )
+            }
+            new_revisions = {
+                r.revision: r
+                for r in sd.revision_map.iterate_revisions(
+                    rev_id,
+                    "base",
+                    inclusive=True,
+                    assert_relative_length=False,
+                )
+            }
+
+            eq_(len(old_revisions), len(new_revisions))
+
+            for common_rev_id in set(old_revisions.keys()).union(
+                new_revisions.keys()
+            ):
+                old_rev = old_revisions[common_rev_id]
+                new_rev = new_revisions[common_rev_id]
+
+                eq_(old_rev.revision, new_rev.revision)
+                eq_(old_rev.down_revision, new_rev.down_revision)
+                eq_(old_rev.dependencies, new_rev.dependencies)
+
+    def _setup_for_fixture(self, revs):
+        self.rev = revs
+
+        self.org_script_dir = ScriptDirectory.from_config(self.cfg)
+        rev_to_script = {
+            script.revision: script
+            for script in self.org_script_dir.walk_revisions()
+        }
+        self._script_by_name = {
+            f"r{i}": rev_to_script[revnum] for i, revnum in enumerate(self.rev)
+        }
+        self._name_by_revision = {
+            v.revision: k for k, v in self._script_by_name.items()
+        }
+
+    @testing.fixture
+    def non_recursive_fixture(self):
+        self.env = staging_env()
+        self.cfg = _sqlite_testing_config()
+
+        ids = [util.rev_id() for i in range(5)]
+
+        script = ScriptDirectory.from_config(self.cfg)
+        script.generate_revision(
+            ids[0], "revision a", refresh=True, head="base"
+        )
+        script.generate_revision(
+            ids[1], "revision b", refresh=True, head=ids[0]
+        )
+        script.generate_revision(
+            ids[2], "revision c", refresh=True, head=ids[1]
+        )
+        script.generate_revision(
+            ids[3], "revision d", refresh=True, head="base"
+        )
+        script.generate_revision(
+            ids[4], "revision e", refresh=True, head=ids[3]
+        )
+
+        self._setup_for_fixture(ids)
+
+        yield
+
+        clear_staging_env()
+
+    @testing.fixture
+    def single_base_fixture(self):
+        self.env = staging_env()
+        self.cfg = _sqlite_testing_config()
+        self.cfg.set_main_option("recursive_version_locations", "true")
+
+        revs = list(three_rev_fixture(self.cfg))
+        revs.extend(multi_heads_fixture(self.cfg, *revs[0:3]))
+
+        self._setup_for_fixture(revs)
+
+        yield
+
+        clear_staging_env()
+
+    @testing.fixture
+    def multi_base_fixture(self):
+
+        self.env = staging_env()
+        self.cfg = _multi_dir_testing_config()
+        self.cfg.set_main_option("recursive_version_locations", "true")
+
+        script0 = command.revision(
+            self.cfg,
+            message="x",
+            head="base",
+            version_path=os.path.join(_get_staging_directory(), "model1"),
+        )
+        assert isinstance(script0, Script)
+        script1 = command.revision(
+            self.cfg,
+            message="y",
+            head="base",
+            version_path=os.path.join(_get_staging_directory(), "model2"),
+        )
+        assert isinstance(script1, Script)
+        script2 = command.revision(
+            self.cfg, message="y2", head=script1.revision
+        )
+        assert isinstance(script2, Script)
+
+        self.org_script_dir = ScriptDirectory.from_config(self.cfg)
+
+        rev_to_script = {
+            script0.revision: script0,
+            script1.revision: script1,
+            script2.revision: script2,
+        }
+
+        self._setup_for_fixture(rev_to_script)
+
+        yield
+
+        clear_staging_env()
+
+    def test_ignore_for_non_recursive(self, non_recursive_fixture):
+        """test traversal is non-recursive when the feature is not enabled
+        (subdirectories are ignored).
+
+        """
+
+        self._setup_revision_files(
+            [
+                "r0",
+                "r1",
+                ("dir_1", ["r2", "r3"]),
+                ("dir_2", ["r4"]),
+            ]
+        )
+
+        vl = "scripts/versions"
+
+        self._assert_setup(
+            ("r0", "", vl),
+            ("r1", "", vl),
+        )
+
+    def test_flat_structure(self, single_base_fixture):
+        assert len(self.rev) == 6
+
+    def test_flat_and_dir_structure(self, single_base_fixture):
+        self._setup_revision_files(
+            [
+                "r1",
+                ("dir_1", ["r0", "r2"]),
+                ("dir_2", ["r4"]),
+                ("dir_3", ["r5"]),
+            ]
+        )
+
+        vl = "scripts/versions"
+
+        self._assert_setup(
+            ("r0", "dir_1", vl),
+            ("r1", "", vl),
+            ("r2", "dir_1", vl),
+            ("r3", "", vl),
+            ("r4", "dir_2", vl),
+            ("r5", "dir_3", vl),
+        )
+
+    def test_nested_dir_structure(self, single_base_fixture):
+        self._setup_revision_files(
+            [
+                (
+                    "dir_1",
+                    ["r0", ("nested_1", ["r1", "r2"]), ("nested_2", ["r3"])],
+                ),
+                ("dir_2", ["r4"]),
+                ("dir_3", [("nested_3", ["r5"])]),
+            ]
+        )
+
+        vl = "scripts/versions"
+
+        self._assert_setup(
+            ("r0", "dir_1", vl),
+            ("r1", "dir_1/nested_1", vl),
+            ("r2", "dir_1/nested_1", vl),
+            ("r3", "dir_1/nested_2", vl),
+            ("r4", "dir_2", vl),
+            ("r5", "dir_3/nested_3", vl),
+        )
+
+    def test_dir_structure_with_missing_file(self, single_base_fixture):
+        sd = ScriptDirectory.from_config(self.cfg)
+
+        revision_to_remove = self._script_by_name["r1"]
+        self._setup_revision_files(
+            [
+                ("delete", ["r1"]),
+                ("dir_1", ["r0", "r2"]),
+                ("dir_2", ["r4"]),
+                ("dir_3", ["r5"]),
+            ]
+        )
+
+        with expect_raises_message(KeyError, revision_to_remove.revision):
+            list(sd.walk_revisions())
+
+    def test_multiple_dir_recursive(self, multi_base_fixture):
+        self._setup_revision_files(
+            [
+                ("dir_0", ["r0"]),
+            ],
+            version_path="model1",
+        )
+        self._setup_revision_files(
+            [
+                ("dir_1", ["r1", ("nested", ["r2"])]),
+            ],
+            version_path="model2",
+        )
+        self._assert_setup(
+            ("r0", "dir_0", "model1"),
+            ("r1", "dir_1", "model2"),
+            ("r2", "dir_1/nested", "model2"),
+        )
+
+    def test_multiple_dir_recursive_change_version_dir(
+        self, multi_base_fixture
+    ):
+        self._setup_revision_files(
+            [
+                ("dir_0", ["r0"]),
+                ("dir_1", ["r1", ("nested", ["r2"])]),
+            ],
+            version_path="model1",
+        )
+        self._assert_setup(
+            ("r0", "dir_0", "model1"),
+            ("r1", "dir_1", "model1"),
+            ("r2", "dir_1/nested", "model1"),
+        )
index bedf545d9d2f9b41c865ccbf709164a8e60cf8ef..bddea5fc7aa252ee608a67a77b244d9019861e8e 100644 (file)
@@ -707,7 +707,7 @@ class ImportsTest(TestBase):
                 context.configure(
                     connection=connection,
                     target_metadata=target_metadata,
-                    **kw
+                    **kw,
                 )
                 with context.begin_transaction():
                     context.run_migrations()