def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]:
with self._catch_revision_errors():
- top_revs = cast(
- Set[Optional[Script]],
- set(self.revision_map.get_revisions(id_)),
- )
- top_revs.update(
- cast(
- Iterator[Script],
- self.revision_map._get_ancestor_nodes(
- list(top_revs), include_dependencies=True
- ),
- )
+ return cast(
+ Set[Optional[Script]], self.revision_map._get_all_current(id_)
)
- top_revs = self.revision_map._filter_into_branch_heads(top_revs)
- return top_revs
def get_revision(self, id_: str) -> Optional[Script]:
"""Return the :class:`.Script` instance with the given rev id.
if TYPE_CHECKING:
from typing import Literal
- from .base import Script
-
_RevIdType = Union[str, Sequence[str]]
_RevisionIdentifierType = Union[str, Tuple[str, ...], None]
_RevisionOrStr = Union["Revision", str]
return revision
def _filter_into_branch_heads(
- self, targets: Set[Optional[Script]]
- ) -> Set[Optional[Script]]:
+ self, targets: Iterable[Optional[_RevisionOrBase]]
+ ) -> Set[Optional[_RevisionOrBase]]:
targets = set(targets)
for rev in list(targets):
def _get_descendant_nodes(
self,
- targets: Collection[Revision],
+ targets: Collection[Optional[_RevisionOrBase]],
map_: Optional[_RevisionMapType] = None,
check: bool = False,
omit_immediate_dependencies: bool = False,
if relative_revision:
# Find target revision relative to current state.
if branch_label:
+ cr_tuple = util.to_tuple(current_revisions)
+ symbol_list: Sequence[str]
symbol_list = self.filter_for_lineage(
- util.to_tuple(current_revisions), branch_label
+ cr_tuple, branch_label
)
+ if not symbol_list:
+ # check the case where there are multiple branches
+ # but there is currently a single heads, since all
+ # other branch heads are dependant of the current
+ # single heads.
+ all_current = cast(
+ Set[Revision], self._get_all_current(cr_tuple)
+ )
+ sl_all_current = self.filter_for_lineage(
+ all_current, branch_label
+ )
+ symbol_list = [
+ r.revision if r else r # type: ignore[misc]
+ for r in sl_all_current
+ ]
+
assert len(symbol_list) == 1
symbol = symbol_list[0]
else:
return needs, tuple(targets) # type:ignore[return-value]
+ def _get_all_current(
+ self, id_: Tuple[str, ...]
+ ) -> Set[Optional[_RevisionOrBase]]:
+ top_revs: Set[Optional[_RevisionOrBase]]
+ top_revs = set(self.get_revisions(id_))
+ top_revs.update(
+ self._get_ancestor_nodes(list(top_revs), include_dependencies=True)
+ )
+ return self._filter_into_branch_heads(top_revs)
+
class Revision:
"""Base class for revisioned objects.
self,
revision: str,
down_revision: Optional[Union[str, Tuple[str, ...]]],
- dependencies: Optional[Tuple[str, ...]] = None,
- branch_labels: Optional[Tuple[str, ...]] = None,
+ dependencies: Optional[Union[str, Tuple[str, ...]]] = None,
+ branch_labels: Optional[Union[str, Tuple[str, ...]]] = None,
) -> None:
if down_revision and revision in util.to_tuple(down_revision):
raise LoopDetected(revision)
import io
import os
import sys
-from typing import Tuple
+from typing import Sequence
from sqlalchemy.util import inspect_getfullargspec # noqa
from sqlalchemy.util.compat import inspect_formatargspec # noqa
from importlib_metadata import EntryPoint # type:ignore # noqa
-def importlib_metadata_get(group: str) -> Tuple[EntryPoint, ...]:
+def importlib_metadata_get(group: str) -> Sequence[EntryPoint]:
ep = importlib_metadata.entry_points()
if hasattr(ep, "select"):
return ep.select(group=group) # type:ignore[attr-defined]
edges, list(reversed(result))
)
- eq_(
- result,
- assertion,
- )
+ eq_(result, assertion)
class DiamondTest(DownIterateTest):
)
+class MultipleBranchEffectiveHead(DownIterateTest):
+ def setUp(self):
+ self.map = RevisionMap(
+ lambda: [
+ Revision("y1", None, branch_labels="y"),
+ Revision("x1", None, branch_labels="x"),
+ Revision("y2", "y1", dependencies="x1"),
+ Revision("x2", "x1"),
+ ]
+ )
+
+ def test_other_downgrade(self):
+ self._assert_iteration(
+ ("x2", "y2"),
+ "x@-1",
+ ["x2"],
+ inclusive=False,
+ select_for_downgrade=True,
+ )
+
+ def test_use_all_current(self):
+ self._assert_iteration(
+ ("x1", "y2"),
+ "x@-1",
+ ["y2", "x1"],
+ inclusive=False,
+ select_for_downgrade=True,
+ )
+
+ def test_effective_head(self):
+ self._assert_iteration(
+ "y2",
+ "x@-1",
+ ["y2", "x1"],
+ inclusive=False,
+ select_for_downgrade=True,
+ )
+
+
class BranchTravellingTest(DownIterateTest):
"""test the order of revs when going along multiple branches.