git.ipfire.org Git - thirdparty/zstd.git/commitdiff
[contrib] Add preprocessor hardwiring to freestanding.py
author: Nick Terrell <terrelln@fb.com>
Tue, 11 Aug 2020 06:09:59 +0000 (23:09 -0700)
committer: Nick Terrell <terrelln@fb.com>
Wed, 9 Sep 2020 21:35:39 +0000 (14:35 -0700)
contrib/freestanding_lib/freestanding.py

index 993626d026591c91b7bbc5f6bca1cb647002fe8b..8165b4a1a1c877b1173aadb6aa1bb021fb736307 100755 (executable)
@@ -17,6 +17,7 @@ import shutil
 import sys
 from typing import Optional
 
+
 INCLUDED_SUBDIRS = ["common", "compress", "decompress"]
 
 SKIPPED_FILES = [
@@ -46,18 +47,360 @@ class FileLines(object):
             f.write("".join(self.lines))
 
 
+class PartialPreprocessor(object):
+    """
+    Looks for simple ifdefs and ifndefs and replaces them.
+    Handles && and ||.
+    Has fancy logic to handle translating elifs to ifs.
+    Only looks for macros in the first part of the expression with no
+    parens.
+    Does not handle multi-line macros (only looks in first line).
+    """
+    def __init__(self, defs: [(str, Optional[str])], replaces: [(str, str)], undefs: [str]):
+        MACRO_GROUP = r"(?P<macro>[a-zA-Z_][a-zA-Z_0-9]*)"
+        ELIF_GROUP = r"(?P<elif>el)?"
+        OP_GROUP = r"(?P<op>&&|\|\|)?"
+
+        self._defs = {macro:value for macro, value in defs}
+        self._replaces = {macro:value for macro, value in replaces}
+        self._defs.update(self._replaces)
+        self._undefs = set(undefs)
+
+        self._define = re.compile(r"\s*#\s*define")
+        self._if = re.compile(r"\s*#\s*if")
+        self._elif = re.compile(r"\s*#\s*(?P<elif>el)if")
+        self._else = re.compile(r"\s*#\s*(?P<else>else)")
+        self._endif = re.compile(r"\s*#\s*endif")
+
+        self._ifdef = re.compile(fr"\s*#\s*if(?P<not>n)?def {MACRO_GROUP}\s*")
+        self._if_defined = re.compile(
+            fr"\s*#\s*{ELIF_GROUP}if\s+(?P<not>!)?\s*defined\s*\(\s*{MACRO_GROUP}\s*\)\s*{OP_GROUP}"
+        )
+        self._if_defined_value = re.compile(
+            fr"\s*#\s*if\s+defined\s*\(\s*{MACRO_GROUP}\s*\)\s*"
+            fr"(?P<op>&&)\s*"
+            fr"(?P<openp>\()?\s*"
+            fr"(?P<macro2>[a-zA-Z_][a-zA-Z_0-9]*)\s*"
+            fr"(?P<cmp>[=><!]+)\s*"
+            fr"(?P<value>[0-9]*)\s*"
+            fr"(?P<closep>\))?\s*"
+        )
+
+        self._c_comment = re.compile(r"/\*.*?\*/")
+        self._cpp_comment = re.compile(r"//")
+
+    def _log(self, *args, **kwargs):
+        print(*args, **kwargs)
+
+    def _strip_comments(self, line):
+        # First strip c-style comments (may include //)
+        while True:
+            m = self._c_comment.search(line)
+            if m is None:
+                break
+            line = line[:m.start()] + line[m.end():]
+
+        # Then strip cpp-style comments
+        m = self._cpp_comment.search(line)
+        if m is not None:
+            line = line[:m.start()]
+
+        return line
+
+    def _fixup_indentation(self, macro, replace: [str]):
+        if len(replace) == 0:
+            return replace
+        if len(replace) == 1 and self._define.match(replace[0]) is None:
+            # If there is only one line, only replace defines
+            return replace
+
+
+        all_pound = True
+        for line in replace:
+            if not line.startswith('#'):
+                all_pound = False
+        if all_pound:
+            replace = [line[1:] for line in replace]
+
+        min_spaces = len(replace[0])
+        for line in replace:
+            spaces = 0
+            for i, c in enumerate(line):
+                if c != ' ':
+                    # Non-preprocessor line ==> skip the fixup
+                    if not all_pound and c != '#':
+                        return replace
+                    spaces = i
+                    break
+            min_spaces = min(min_spaces, spaces)
+
+        replace = [line[min_spaces:] for line in replace]
+
+        if all_pound:
+            replace = ["#" + line for line in replace]
+
+        return replace
+
+    def _handle_if_block(self, macro, idx, is_true, prepend):
+        """
+        Remove the #if or #elif block starting on this line.
+        """
+        REMOVE_ONE = 0
+        KEEP_ONE = 1
+        REMOVE_REST = 2
+
+        if is_true:
+            state = KEEP_ONE
+        else:
+            state = REMOVE_ONE
+
+        line = self._inlines[idx]
+        is_if = self._if.match(line) is not None
+        assert is_if or self._elif.match(line) is not None
+        depth = 0
+
+        start_idx = idx
+
+        idx += 1
+        replace = prepend
+        finished = False
+        while idx < len(self._inlines):
+            line = self._inlines[idx]
+            # Nested if statement
+            if self._if.match(line):
+                depth += 1
+                idx += 1
+                continue
+            # We're inside a nested statement
+            if depth > 0:
+                if self._endif.match(line):
+                    depth -= 1
+                idx += 1
+                continue
+
+            # We're at the original depth
+
+            # Looking only for an endif.
+            # We've found a true statement, but haven't
+            # completely elided the if block, so we just
+            # remove the remainder.
+            if state == REMOVE_REST:
+                if self._endif.match(line):
+                    if is_if:
+                        # Remove the endif because we took the first if
+                        idx += 1
+                    finished = True
+                    break
+                idx += 1
+                continue
+
+            if state == KEEP_ONE:
+                m = self._elif.match(line)
+                if self._endif.match(line):
+                    replace += self._inlines[start_idx + 1:idx]
+                    idx += 1
+                    finished = True
+                    break
+                if self._elif.match(line) or self._else.match(line):
+                    replace += self._inlines[start_idx + 1:idx]
+                    state = REMOVE_REST
+                idx += 1
+                continue
+
+            if state == REMOVE_ONE:
+                m = self._elif.match(line)
+                if m is not None:
+                    if is_if:
+                        idx += 1
+                        b = m.start('elif')
+                        e = m.end('elif')
+                        assert e - b == 2
+                        replace.append(line[:b] + line[e:])
+                    finished = True
+                    break
+                m = self._else.match(line)
+                if m is not None:
+                    if is_if:
+                        idx += 1
+                        while self._endif.match(self._inlines[idx]) is None:
+                            replace.append(self._inlines[idx])
+                            idx += 1
+                        idx += 1
+                    finished = True
+                    break
+                if self._endif.match(line):
+                    if is_if:
+                        # Remove the endif because no other elifs
+                        idx += 1
+                    finished = True
+                    break
+                idx += 1
+                continue
+        if not finished:
+            raise RuntimeError("Unterminated if block!")
+
+        replace = self._fixup_indentation(macro, replace)
+
+        self._log(f"\tHardwiring {macro}")
+        if start_idx > 0:
+            self._log(f"\t\t  {self._inlines[start_idx - 1][:-1]}")
+        for x in range(start_idx, idx):
+            self._log(f"\t\t- {self._inlines[x][:-1]}")
+        for line in replace:
+            self._log(f"\t\t+ {line[:-1]}")
+        if idx < len(self._inlines):
+            self._log(f"\t\t  {self._inlines[idx][:-1]}")
+
+        return idx, replace
+
+    def _preprocess_once(self):
+        outlines = []
+        idx = 0
+        changed = False
+        while idx < len(self._inlines):
+            line = self._inlines[idx]
+            sline = self._strip_comments(line)
+            m = self._ifdef.fullmatch(sline)
+            if m is None:
+                m = self._if_defined_value.fullmatch(sline)
+            if m is None:
+                m = self._if_defined.match(sline)
+            if m is None:
+                outlines.append(line)
+                idx += 1
+                continue
+
+            groups = m.groupdict()
+            macro = groups['macro']
+            ifdef = groups.get('not') is None
+            elseif = groups.get('elif') is not None
+            op = groups.get('op')
+
+            macro2 = groups.get('macro2')
+            cmp = groups.get('cmp')
+            value = groups.get('value')
+            openp = groups.get('openp')
+            closep = groups.get('closep')
+
+            if not (macro in self._defs or macro in self._undefs):
+                outlines.append(line)
+                idx += 1
+                continue
+
+            defined = macro in self._defs
+            is_true = (ifdef == defined)
+            resolved = True
+            if op is not None:
+                if op == '&&':
+                    resolved = not is_true
+                else:
+                    assert op == '||'
+                    resolved = is_true
+
+            if macro2 is not None and not resolved:
+                assert ifdef and defined and op == '&&' and cmp is not None
+                # If the statement is true, but we have a single value check, then
+                # check the value.
+                defined_value = self._defs[macro]
+                are_ints = True
+                try:
+                    defined_value = int(defined_value)
+                    value = int(value)
+                except TypeError:
+                    are_ints = False
+                except ValueError:
+                    are_ints = False
+                if (
+                        macro == macro2 and
+                        ((openp is None) == (closep is None)) and
+                        are_ints
+                ):
+                    resolved = True
+                    if cmp == '<':
+                        is_true = defined_value < value
+                    elif cmp == '<=':
+                        is_true = defined_value <= value
+                    elif cmp == '==':
+                        is_true = defined_value == value
+                    elif cmp == '!=':
+                        is_true = defined_value != value
+                    elif cmp == '>=':
+                        is_true = defined_value >= value
+                    elif cmp == '>':
+                        is_true = defined_value > value
+                    else:
+                        resolved = False
+
+            if op is not None and not resolved:
+                # Remove the first op in the line + spaces
+                if op == '&&':
+                    opre = op
+                else:
+                    assert op == '||'
+                    opre = r'\|\|'
+                needle = re.compile(fr"(?P<if>\s*#\s*(el)?if\s+).*?(?P<op>{opre}\s*)")
+                match = needle.match(line)
+                assert match is not None
+                newline = line[:match.end('if')] + line[match.end('op'):]
+
+                self._log(f"\tHardwiring partially resolved {macro}")
+                self._log(f"\t\t- {line[:-1]}")
+                self._log(f"\t\t+ {newline[:-1]}")
+
+                outlines.append(newline)
+                idx += 1
+                continue
+
+            # Skip any statements we cannot fully compute
+            if not resolved:
+                outlines.append(line)
+                idx += 1
+                continue
+
+            prepend = []
+            if macro in self._replaces:
+                assert not ifdef
+                assert op is None
+                value = self._replaces.pop(macro)
+                prepend = [f"#define {macro} {value}\n"]
+
+            idx, replace = self._handle_if_block(macro, idx, is_true, prepend)
+            outlines += replace
+            changed = True
+
+        return changed, outlines
+
+    def preprocess(self, filename):
+        with open(filename, 'r') as f:
+            self._inlines = f.readlines()
+        changed = True
+        iters = 0
+        while changed:
+            iters += 1
+            changed, outlines = self._preprocess_once()
+            self._inlines = outlines
+
+        with open(filename, 'w') as f:
+            f.write(''.join(self._inlines))
+
+
 class Freestanding(object):
     def __init__(
             self,zstd_deps: str, source_lib: str, output_lib: str,
-            external_xxhash: bool, rewritten_includes: [(str, str)],
-            defs: [(str, Optional[str])], undefs: [str], excludes: [str]
+            external_xxhash: bool, xxh64_state: Optional[str],
+            xxh64_prefix: Optional[str], rewritten_includes: [(str, str)],
+            defs: [(str, Optional[str])], replaces: [(str, str)],
+            undefs: [str], excludes: [str]
     ):
         self._zstd_deps = zstd_deps
         self._src_lib = source_lib
         self._dst_lib = output_lib
         self._external_xxhash = external_xxhash
+        self._xxh64_state = xxh64_state
+        self._xxh64_prefix = xxh64_prefix
         self._rewritten_includes = rewritten_includes
         self._defs = defs
+        self._replaces = replaces
         self._undefs = undefs
         self._excludes = excludes
 
@@ -121,14 +464,10 @@ class Freestanding(object):
             file = FileLines(filepath)
     
     def _hardwire_defines(self):
-        self._log("Hardwiring defined macros")
-        for (name, value) in self._defs:
-            self._log(f"\tHardwiring: #define {name} {value}")
-            self._hardwire_preprocessor(name, value=value)
-        self._log("Hardwiring undefined macros")
-        for name in self._undefs:
-            self._log(f"\tHardwiring: #undef {name}")
-            self._hardwire_preprocessor(name, undef=True)
+        self._log("Hardwiring macros")
+        partial_preprocessor = PartialPreprocessor(self._defs, self._replaces, self._undefs)
+        for filepath in self._dst_lib_file_paths():
+            partial_preprocessor.preprocess(filepath)
 
     def _remove_excludes(self):
         self._log("Removing excluded sections")
@@ -180,15 +519,47 @@ class Freestanding(object):
         for original, rewritten in self._rewritten_includes:
             self._rewrite_include(original, rewritten)
     
+    def _replace_xxh64_prefix(self):
+        if self._xxh64_prefix is None:
+            return
+        self._log(f"Replacing XXH64 prefix with {self._xxh64_prefix}")
+        replacements = []
+        if self._xxh64_state is not None:
+            replacements.append(
+                (re.compile(r"([^\w]|^)(?P<orig>XXH64_state_t)([^\w]|$)"), self._xxh64_state)
+            )
+        if self._xxh64_prefix is not None:
+            replacements.append(
+                (re.compile(r"([^\w]|^)(?P<orig>XXH64)_"), self._xxh64_prefix)
+            )
+        for filepath in self._dst_lib_file_paths():
+            file = FileLines(filepath)
+            for i, line in enumerate(file.lines):
+                modified = False
+                for regex, replacement in replacements:
+                    match = regex.search(line)
+                    while match is not None:
+                        modified = True
+                        b = match.start('orig')
+                        e = match.end('orig')
+                        line = line[:b] + replacement + line[e:]
+                        match = regex.search(line)
+                if modified:
+                    self._log(f"\t- {file.lines[i][:-1]}")
+                    self._log(f"\t+ {line[:-1]}")
+                file.lines[i] = line
+            file.write()
+
     def go(self):
         self._copy_source_lib()
         self._copy_zstd_deps()
         self._hardwire_defines()
         self._remove_excludes()
         self._rewrite_includes()
+        self._replace_xxh64_prefix()
 
 
-def parse_defines(defines: [str]) -> [(str, Optional[str])]:
+def parse_optional_pair(defines: [str]) -> [(str, Optional[str])]:
     output = []
     for define in defines:
         parsed = define.split('=')
@@ -201,7 +572,7 @@ def parse_defines(defines: [str]) -> [(str, Optional[str])]:
     return output
 
 
-def parse_rewritten_includes(rewritten_includes: [str]) -> [(str, str)]:
+def parse_pair(rewritten_includes: [str]) -> [(str, str)]:
     output = []
     for rewritten_include in rewritten_includes:
         parsed = rewritten_include.split('=')
@@ -219,9 +590,12 @@ def main(name, args):
     parser.add_argument("--source-lib", default="../../lib", help="Location of the zstd library")
     parser.add_argument("--output-lib", default="./freestanding_lib", help="Where to output the freestanding zstd library")
     parser.add_argument("--xxhash", default=None, help="Alternate external xxhash include e.g. --xxhash='<xxhash.h>'. If set xxhash is not included.")
+    parser.add_argument("--xxh64-state", default=None, help="Alternate XXH64 state type (excluding _) e.g. --xxh64-state='struct xxh64_state'")
+    parser.add_argument("--xxh64-prefix", default=None, help="Alternate XXH64 function prefix (excluding _) e.g. --xxh64-prefix=xxh64")
     parser.add_argument("--rewrite-include", default=[], dest="rewritten_includes", action="append", help="Rewrite an include REGEX=NEW (e.g. '<stddef\\.h>=<linux/types.h>')")
     parser.add_argument("-D", "--define", default=[], dest="defs", action="append", help="Pre-define this macro (can be passed multiple times)")
     parser.add_argument("-U", "--undefine", default=[], dest="undefs", action="append", help="Pre-undefine this macro (can be passed mutliple times)")
+    parser.add_argument("-R", "--replace", default=[], dest="replaces", action="append", help="Pre-define this macro and replace the first ifndef block with its definition")
     parser.add_argument("-E", "--exclude", default=[], dest="excludes", action="append", help="Exclude all lines between 'BEGIN <EXCLUDE>' and 'END <EXCLUDE>'")
     args = parser.parse_args(args)
 
@@ -229,22 +603,37 @@ def main(name, args):
     if "ZSTD_MULTITHREAD" not in args.undefs:
         args.undefs.append("ZSTD_MULTITHREAD")
 
-    args.defs = parse_defines(args.defs)
+    args.defs = parse_optional_pair(args.defs)
     for name, _ in args.defs:
         if name in args.undefs:
             raise RuntimeError(f"{name} is both defined and undefined!")
 
-    args.rewritten_includes = parse_rewritten_includes(args.rewritten_includes)
+    args.replaces = parse_pair(args.replaces)
+    for name, _ in args.replaces:
+        if name in args.undefs or name in args.defs:
+            raise RuntimeError(f"{name} is both replaced and (un)defined!")
+
+    args.rewritten_includes = parse_pair(args.rewritten_includes)
 
     external_xxhash = False
     if args.xxhash is not None:
         external_xxhash = True
         args.rewritten_includes.append(('"(\\.\\./common/)?xxhash.h"', args.xxhash))
 
+    if args.xxh64_prefix is not None:
+        if not external_xxhash:
+            raise RuntimeError("--xxh64-prefix may only be used with --xxhash provided")
+
+    if args.xxh64_state is not None:
+        if not external_xxhash:
+            raise RuntimeError("--xxh64-state may only be used with --xxhash provided")
+
     print(args.zstd_deps)
     print(args.output_lib)
     print(args.source_lib)
     print(args.xxhash)
+    print(args.xxh64_state)
+    print(args.xxh64_prefix)
     print(args.rewritten_includes)
     print(args.defs)
     print(args.undefs)
@@ -254,8 +643,11 @@ def main(name, args):
         args.source_lib,
         args.output_lib,
         external_xxhash,
+        args.xxh64_state,
+        args.xxh64_prefix,
         args.rewritten_includes,
         args.defs,
+        args.replaces,
         args.undefs,
         args.excludes
     ).go()