]> git.ipfire.org Git - thirdparty/mkosi.git/commitdiff
Reimplement GenericVersion without shelling out to systemd-analyze
authorJoerg Behrmann <behrmann@physik.fu-berlin.de>
Fri, 7 Jul 2023 11:49:50 +0000 (13:49 +0200)
committerJoerg Behrmann <behrmann@physik.fu-berlin.de>
Wed, 12 Jul 2023 08:47:28 +0000 (10:47 +0200)
mkosi/config.py
tests/test_config.py

index 2bc181f3745ce8ca84620722a3f683abd83f2341..cb1cbbf509c9473a92048d1bd5715ea64ca19ef3 100644 (file)
@@ -19,6 +19,7 @@ import subprocess
 import sys
 import textwrap
 from collections.abc import Sequence
+from itertools import takewhile
 from pathlib import Path
 from typing import Any, Callable, Optional, Type, Union, cast
 
@@ -1809,44 +1810,151 @@ class MkosiConfigParser:
 
 
 class GenericVersion:
+    # These constants follow the convention of the return value of rpmdev-vercmp that are followe
+    # by systemd-analyze compare-versions when called with only two arguments (without a comparison
+    # operator), recreated in the compare_versions method.
+    _EQUAL = 0
+    _RIGHT_SMALLER = 11
+    _LEFT_SMALLER = 12
+
     def __init__(self, version: str):
         self._version = version
 
+    @classmethod
+    def compare_versions(cls, v1: str, v2: str) -> int:
+        """Implements comparison according to UAPI Group Version Format Specification"""
+        def rstrip_invalid_version_chars(s: str) -> str:
+            valid_version_chars = {*string.ascii_letters, *string.digits, "~", "-", "^", "."}
+            for i, c in enumerate(s):
+                if c in valid_version_chars:
+                    return s[i:]
+            return ""
+
+        def digit_prefix(s: str) -> str:
+            return "".join(takewhile(lambda c: c in string.digits, s))
+
+        def letter_prefix(s: str) -> str:
+            return "".join(takewhile(lambda c: c in string.ascii_letters, s))
+
+        while True:
+            #breakpoint()
+            # Any characters which are outside of the set of listed above (a-z, A-Z, 0-9, -, ., ~,
+            # ^) are skipped in both strings. In particular, this means that non-ASCII characters
+            # that are Unicode digits or letters are skipped too.
+            v1 = rstrip_invalid_version_chars(v1)
+            v2 = rstrip_invalid_version_chars(v2)
+            # If the remaining part of one of strings starts with "~": if other remaining part does
+            # not start with ~, the string with ~ compares lower. Otherwise, both tilde characters
+            # are skipped.
+            if v1.startswith("~") and v2.startswith("~"):
+                v1 = v1.removeprefix("~")
+                v2 = v2.removeprefix("~")
+            elif v1.startswith("~"):
+                return cls._LEFT_SMALLER
+            elif v2.startswith("~"):
+                return cls._RIGHT_SMALLER
+            # If one of the strings has ended: if the other string hasn’t, the string that has
+            # remaining characters compares higher. Otherwise, the strings compare equal.
+            if not v1 and not v2:
+                return cls._EQUAL
+            elif not v1 and v2:
+                return cls._LEFT_SMALLER
+            elif v1 and not v2:
+                return cls._RIGHT_SMALLER
+            # If the remaining part of one of strings starts with "-": if the other remaining part
+            # does not start with -, the string with - compares lower. Otherwise, both minus
+            # characters are skipped.
+            if v1.startswith("-") and v2.startswith("-"):
+                v1 = v1.removeprefix("-")
+                v2 = v2.removeprefix("-")
+            elif v1.startswith("-"):
+                return cls._LEFT_SMALLER
+            elif v2.startswith("-"):
+                return cls._RIGHT_SMALLER
+            # If the remaining part of one of strings starts with "^": if the other remaining part
+            # does not start with ^, the string with ^ compares higher. Otherwise, both caret
+            # characters are skipped.
+            if v1.startswith("^") and v2.startswith("^"):
+                v1 = v1.removeprefix("^")
+                v2 = v2.removeprefix("^")
+            elif v1.startswith("^"):
+                # TODO: bug?
+                return cls._LEFT_SMALLER  #cls._RIGHT_SMALLER
+            elif v2.startswith("^"):
+                return cls._RIGHT_SMALLER #cls._LEFT_SMALLER
+            # If the remaining part of one of strings starts with ".": if the other remaining part
+            # does not start with ., the string with . compares lower. Otherwise, both dot
+            # characters are skipped.
+            if v1.startswith(".") and v2.startswith("."):
+                v1 = v1.removeprefix(".")
+                v2 = v2.removeprefix(".")
+            elif v1.startswith("."):
+                return cls._LEFT_SMALLER
+            elif v2.startswith("."):
+                return cls._RIGHT_SMALLER
+            # If either of the remaining parts starts with a digit: numerical prefixes are compared
+            # numerically. Any leading zeroes are skipped. The numerical prefixes (until the first
+            # non-digit character) are evaluated as numbers. If one of the prefixes is empty, it
+            # evaluates as 0. If the numbers are different, the string with the bigger number
+            # compares higher. Otherwise, the comparison continues at the following characters at
+            # point 1.
+            v1_digit_prefix = digit_prefix(v1)
+            v2_digit_prefix = digit_prefix(v2)
+            if v1_digit_prefix or v2_digit_prefix:
+                v1_digits = int(v1_digit_prefix) if v1_digit_prefix else 0
+                v2_digits = int(v2_digit_prefix) if v2_digit_prefix else 0
+                if v1_digits < v2_digits:
+                    return cls._LEFT_SMALLER
+                elif v1_digits > v2_digits:
+                    return cls._RIGHT_SMALLER
+
+                v1 = v1.removeprefix(v1_digit_prefix)
+                v2 = v2.removeprefix(v2_digit_prefix)
+                continue
+            # Leading alphabetical prefixes are compared alphabetically. The substrings are
+            # compared letter-by-letter. If both letters are the same, the comparison continues
+            # with the next letter. Capital letters compare lower than lower-case letters (A <
+            # a). When the end of one substring has been reached (a non-letter character or the end
+            # of the whole string), if the other substring has remaining letters, it compares
+            # higher. Otherwise, the comparison continues at the following characters at point 1.
+            v1_letter_prefix = letter_prefix(v1)
+            v2_letter_prefix = letter_prefix(v2)
+            if v1_letter_prefix < v2_letter_prefix:
+                return cls._LEFT_SMALLER
+            elif v1_letter_prefix > v2_letter_prefix:
+                return cls._RIGHT_SMALLER
+            v1 = v1.removeprefix(v1_letter_prefix)
+            v2 = v2.removeprefix(v2_letter_prefix)
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, GenericVersion):
             return False
-        cmd = ["systemd-analyze", "compare-versions", self._version, "eq", other._version]
-        return run(cmd, check=False).returncode == 0
+        return self.compare_versions(self._version, other._version) == self._EQUAL
 
     def __ne__(self, other: object) -> bool:
         if not isinstance(other, GenericVersion):
             return False
-        cmd = ["systemd-analyze", "compare-versions", self._version, "ne", other._version]
-        return run(cmd, check=False).returncode == 0
+        return self.compare_versions(self._version, other._version) != self._EQUAL
 
     def __lt__(self, other: object) -> bool:
         if not isinstance(other, GenericVersion):
             return False
-        cmd = ["systemd-analyze", "compare-versions", self._version, "lt", other._version]
-        return run(cmd, check=False).returncode == 0
+        return self.compare_versions(self._version, other._version) == self._LEFT_SMALLER
 
     def __le__(self, other: object) -> bool:
         if not isinstance(other, GenericVersion):
             return False
-        cmd = ["systemd-analyze", "compare-versions", self._version, "le", other._version]
-        return run(cmd, check=False).returncode == 0
+        return self.compare_versions(self._version, other._version) in (self._EQUAL, self._LEFT_SMALLER)
 
     def __gt__(self, other: object) -> bool:
         if not isinstance(other, GenericVersion):
             return False
-        cmd = ["systemd-analyze", "compare-versions", self._version, "gt", other._version]
-        return run(cmd, check=False).returncode == 0
+        return self.compare_versions(self._version, other._version) == self._RIGHT_SMALLER
 
     def __ge__(self, other: object) -> bool:
         if not isinstance(other, GenericVersion):
             return False
-        cmd = ["systemd-analyze", "compare-versions", self._version, "ge", other._version]
-        return run(cmd, check=False).returncode == 0
+        return self.compare_versions(self._version, other._version) in (self._EQUAL, self._RIGHT_SMALLER)
 
 
 def find_image_version(args: argparse.Namespace) -> None:
index abdcbf515739e4a75c49f629d34a9c6acc1bc74f..a1149669398e6d4000d5a503792fb5cad6622e65 100644 (file)
@@ -15,6 +15,9 @@ def test_generic_version_systemd() -> None:
     assert not (GenericVersion("1") > GenericVersion("2"))
     assert not (GenericVersion("1") == GenericVersion("2"))
     assert not (GenericVersion("1") >= GenericVersion("2"))
+    assert GenericVersion.compare_versions("1", "2") == 12
+    assert GenericVersion.compare_versions("2", "2") == 0
+    assert GenericVersion.compare_versions("2", "1") == 11
 
 
 def test_generic_version_spec() -> None: