# along with GCC; see the file COPYING3. If not see
# <http://www.gnu.org/licenses/>.
-# TODO: Extract riscv_subset_t from riscv-common.cc and make it can be compiled
-# standalone to replace this script, that also prevents us implementing
-# that twice and keep sync again and again.
-
from __future__ import print_function
import sys
import argparse
import collections
import itertools
+import re
+import os
from functools import reduce
SUPPORTED_ISA_SPEC = ["2.2", "20190608", "20191213"]
CANONICAL_ORDER = "imafdqlcbkjtpvnh"
LONG_EXT_PREFIXES = ['z', 's', 'h', 'x']
+def _find_closing_paren(content, open_pos):
+    """Return the index of the ')' matching the '(' at open_pos, or -1.
+
+    Parentheses inside double-quoted string literals, /* ... */ block
+    comments and // line comments are ignored, so text such as
+    "foo (bar" in a macro argument cannot unbalance the scan.
+    """
+    depth = 0
+    i = open_pos
+    length = len(content)
+    while i < length:
+        char = content[i]
+        if char == '"':
+            # Jump to the closing quote, stepping over backslash escapes.
+            i += 1
+            while i < length and content[i] != '"':
+                i += 2 if content[i] == '\\' else 1
+        elif content.startswith('/*', i):
+            end = content.find('*/', i + 2)
+            if end < 0:
+                return -1
+            # Land on the '/' of the terminator; the increment below
+            # moves past it.
+            i = end + 1
+        elif content.startswith('//', i):
+            end = content.find('\n', i)
+            if end < 0:
+                return -1
+            i = end
+        elif char == '(':
+            depth += 1
+        elif char == ')':
+            depth -= 1
+            if depth == 0:
+                return i
+        i += 1
+    return -1
+
+
+def parse_define_riscv_ext(content):
+    """Collect every DEFINE_RISCV_EXT (...) invocation found in content.
+
+    Args:
+        content: Text of a .def file.
+
+    Returns:
+        A list holding the parse_macro_arguments result of each macro
+        invocation, in file order.  Invocations that parse_macro_arguments
+        rejects are skipped, and scanning stops at an invocation whose
+        parentheses never balance.
+    """
+    macro_start = re.compile(r'DEFINE_RISCV_EXT\s*\(')
+    extensions = []
+    pos = 0
+    while True:
+        # Search from pos directly instead of re-slicing content each time.
+        match = macro_start.search(content, pos)
+        if match is None:
+            break
+
+        # match.end() - 1 is the opening parenthesis of the invocation.
+        close_pos = _find_closing_paren(content, match.end() - 1)
+        if close_pos < 0:
+            # Unbalanced macro: nothing after it can be delimited reliably.
+            break
+
+        ext_data = parse_macro_arguments(content[match.end():close_pos])
+        if ext_data:
+            extensions.append(ext_data)
+        pos = close_pos + 1
+
+    return extensions
+
+def parse_macro_arguments(macro_content):
+    """Split the argument text of one DEFINE_RISCV_EXT macro.
+
+    Args:
+        macro_content: Text between the macro's outer parentheses.
+
+    Returns:
+        A dict with 'name' (argument 0) and 'dep_exts' (argument 5 parsed
+        by parse_dep_exts), or None when fewer than six arguments are
+        found.
+    """
+    # Non-greedy with DOTALL so that a comment may contain '*' characters;
+    # the previous [^*]* form left such comments in place and any comma
+    # inside them was taken for an argument separator.
+    cleaned_content = re.sub(r'/\*.*?\*/', '', macro_content,
+                             flags=re.DOTALL)
+
+    # Split on commas that sit outside strings, parentheses and braces.
+    args = []
+    current = []
+    paren_depth = 0
+    brace_depth = 0
+    in_string = False
+    escape_next = False
+
+    for char in cleaned_content:
+        if escape_next:
+            # Character following a backslash is copied verbatim.
+            current.append(char)
+            escape_next = False
+            continue
+
+        if char == '\\':
+            escape_next = True
+            current.append(char)
+            continue
+
+        if char == '"':
+            in_string = not in_string
+            current.append(char)
+            continue
+
+        if in_string:
+            current.append(char)
+            continue
+
+        if char == '(':
+            paren_depth += 1
+        elif char == ')':
+            paren_depth -= 1
+        elif char == '{':
+            brace_depth += 1
+        elif char == '}':
+            brace_depth -= 1
+        elif char == ',' and paren_depth == 0 and brace_depth == 0:
+            args.append(''.join(current).strip())
+            current = []
+            continue
+
+        current.append(char)
+
+    # Keep the trailing argument only when it is not blank.
+    last_arg = ''.join(current).strip()
+    if last_arg:
+        args.append(last_arg)
+
+    # DEP_EXTS is the sixth argument (index 5).
+    if len(args) < 6:
+        return None
+
+    return {
+        'name': args[0].strip(),
+        'dep_exts': parse_dep_exts(args[5].strip())
+    }
+
+def _string_end(text, start):
+    """Return the index of the quote closing the string opened at start.
+
+    Backslash escapes are stepped over; len(text) is returned when the
+    string is unterminated.
+    """
+    i = start + 1
+    length = len(text)
+    while i < length and text[i] != '"':
+        i += 2 if text[i] == '\\' else 1
+    return min(i, length)
+
+
+def _split_dep_items(text):
+    """Split text into its top-level quoted strings and brace groups.
+
+    Returns a list of (is_group, body) pairs: body is the string content
+    for a quoted string, or the text between the matching braces for a
+    {...} group.  Braces inside string literals are not counted.
+    """
+    items = []
+    i = 0
+    length = len(text)
+    while i < length:
+        char = text[i]
+        if char == '"':
+            end = _string_end(text, i)
+            items.append((False, text[i + 1:end]))
+            i = end + 1
+        elif char == '{':
+            depth = 0
+            end = i
+            while end < length:
+                inner = text[end]
+                if inner == '"':
+                    end = _string_end(text, end)
+                elif inner == '{':
+                    depth += 1
+                elif inner == '}':
+                    depth -= 1
+                    if depth == 0:
+                        break
+                end += 1
+            items.append((True, text[i + 1:end]))
+            i = end + 1
+        else:
+            i += 1
+    return items
+
+
+def parse_dep_exts(dep_exts_str):
+    """Parse a DEP_EXTS macro argument into a list of dependencies.
+
+    Args:
+        dep_exts_str: The argument text, e.g. '({"zca", {"zcf", <lambda>}})'.
+
+    Returns:
+        A list of dicts without duplicates.  Conditional entries
+        ({'ext', 'type': 'conditional', 'condition'}) come first, followed
+        by simple ones ({'ext', 'type': 'simple'}).
+    """
+    # Remove outer parentheses if present
+    dep_exts_str = dep_exts_str.strip()
+    if dep_exts_str.startswith('(') and dep_exts_str.endswith(')'):
+        dep_exts_str = dep_exts_str[1:-1].strip()
+
+    # Remove outer braces if present
+    if dep_exts_str.startswith('{') and dep_exts_str.endswith('}'):
+        dep_exts_str = dep_exts_str[1:-1].strip()
+
+    if not dep_exts_str:
+        return []
+
+    group_re = re.compile(r'\s*"([^"]+)"\s*,\s*(.*)', re.DOTALL)
+    lambda_head_re = re.compile(r'\[.*?\]\s*\([^)]*\)\s*->\s*bool',
+                                re.DOTALL)
+    quoted_re = re.compile(r'"([^"]+)"')
+
+    conditional = []
+    simple = []
+    # Groups are delimited by brace matching rather than by a regex ending
+    # at the first '}', so a lambda body with nested braces stays in one
+    # piece and its string literals are not mistaken for dependencies.
+    for is_group, body in _split_dep_items(dep_exts_str):
+        if not is_group:
+            if body:
+                simple.append(body)
+            continue
+
+        group = group_re.match(body)
+        if group and lambda_head_re.match(group.group(2)):
+            conditional.append({'ext': group.group(1),
+                                'type': 'conditional',
+                                'condition': group.group(2).strip()})
+        else:
+            # Not an {"ext", lambda} pair: every quoted name is a plain
+            # dependency.
+            simple.extend(quoted_re.findall(body))
+
+    deps = conditional + [{'ext': name, 'type': 'simple'} for name in simple]
+
+    # Remove duplicates while preserving order
+    seen = set()
+    unique_deps = []
+    for dep in deps:
+        key = (dep['ext'], dep['type'])
+        if key not in seen:
+            seen.add(key)
+            unique_deps.append(dep)
+
+    return unique_deps
+
+def evaluate_conditional_dependency(ext, dep, xlen, current_exts):
+    """Decide whether the conditional dependency dep of ext applies.
+
+    Args:
+        ext: Extension that declares the dependency.
+        dep: Dependency dict; 'ext' and 'condition' are read.
+        xlen: Register width taken from the arch string (32 or 64).
+        current_exts: Extensions known so far.
+
+    Returns:
+        True when the dependency should be added.  Combinations that are
+        not hard-coded here are reported on stderr and yield False.
+    """
+    implied = dep['ext']
+    condition = dep['condition']
+
+    # zcf is implied only on RV32 when F is present.
+    if implied == 'zcf' and ext in ('zca', 'c', 'zce'):
+        return xlen == 32 and 'f' in current_exts
+
+    # zcd is implied whenever D is present.
+    if implied == 'zcd' and ext in ('zca', 'c'):
+        return 'd' in current_exts
+
+    # zca implies c only once the matching compressed FP parts are there.
+    if implied == 'c' and ext == 'zca':
+        has_d = 'd' in current_exts
+        if xlen == 32:
+            if has_d:
+                return 'zcf' in current_exts and 'zcd' in current_exts
+            if 'f' in current_exts:
+                return 'zcf' in current_exts
+            return True
+        if xlen == 64:
+            return 'zcd' in current_exts if has_d else True
+        return False
+
+    # Anything else is unknown to this script: report it and stay safe.
+    import sys
+    print(f"ERROR: Unhandled conditional dependency: '{implied}' with condition:", file=sys.stderr)
+    print(f" Condition code: {condition[:100]}...", file=sys.stderr)
+    print(f" Current context: xlen={xlen}, exts={sorted(current_exts)}", file=sys.stderr)
+    return False
+
+def resolve_dependencies(arch_parts, xlen):
+    """Compute the extensions implied by arch_parts, to a fixed point.
+
+    Args:
+        arch_parts: Extensions named in the arch string.
+        xlen: Register width, passed on to conditional evaluation.
+
+    Returns:
+        The set of implied extensions that are not in arch_parts.
+    """
+    current_exts = set(arch_parts)
+    implied_deps = set()
+
+    while True:
+        # Snapshot of everything known at the start of this pass; names
+        # found during the pass only take effect in the next one.
+        known = current_exts | implied_deps
+        discovered = set()
+
+        for ext in known:
+            for dep in IMPLIED_EXT.get(ext, ()):
+                if dep['type'] == 'simple':
+                    wanted = True
+                elif dep['type'] == 'conditional':
+                    wanted = evaluate_conditional_dependency(
+                        ext, dep, xlen, known)
+                else:
+                    wanted = False
+
+                if wanted and dep['ext'] not in known:
+                    discovered.add(dep['ext'])
+
+        implied_deps.update(discovered)
+        # A pass that adds nothing means the closure is complete.
+        if not discovered:
+            break
+
+    return implied_deps
+
+def parse_def_file(file_path, script_dir, processed_files=None, collect_all=False):
+    """Parse one .def file, following its #include "..." directives.
+
+    Args:
+        file_path: Path of the .def file to read.
+        script_dir: Directory against which included names are resolved.
+        processed_files: Paths already visited; guards against include
+            cycles and is updated in place.
+        collect_all: When true, also gather the name of every extension.
+
+    Returns:
+        A dict mapping extension name to its non-empty dependency list.
+        With collect_all the result is the pair (that dict, set of all
+        extension names).  A missing or already visited file contributes
+        nothing.
+    """
+    if processed_files is None:
+        processed_files = set()
+
+    def result(implied, names):
+        # Shape the return value according to collect_all.
+        return (implied, names) if collect_all else implied
+
+    # Avoid infinite recursion
+    if file_path in processed_files:
+        return result({}, set())
+    processed_files.add(file_path)
+
+    implied_ext = {}
+    all_extensions = set() if collect_all else None
+
+    if not os.path.exists(file_path):
+        return result(implied_ext, all_extensions)
+
+    with open(file_path, 'r') as f:
+        content = f.read()
+
+    # Included files are merged first so that this file's own definitions
+    # override them.
+    for include_file in re.findall(r'#include\s+"([^"]+)"', content):
+        nested = parse_def_file(os.path.join(script_dir, include_file),
+                                script_dir, processed_files, collect_all)
+        if collect_all:
+            nested_implied, nested_names = nested
+            all_extensions.update(nested_names)
+        else:
+            nested_implied = nested
+        implied_ext.update(nested_implied)
+
+    for ext_data in parse_define_riscv_ext(content):
+        name = ext_data['name']
+        if collect_all:
+            all_extensions.add(name)
+        if ext_data['dep_exts']:
+            implied_ext[name] = ext_data['dep_exts']
+
+    return result(implied_ext, all_extensions)
+
+def parse_def_files():
+    """Load the implied-extension table from riscv-ext.def.
+
+    The file is looked up next to this script, or in the current working
+    directory when __file__ is not defined (e.g., interactive mode).
+
+    Returns:
+        The result of parse_def_file for that file.
+    """
+    try:
+        script_path = os.path.abspath(__file__)
+    except NameError:
+        base_dir = os.getcwd()
+    else:
+        base_dir = os.path.dirname(script_path)
+
+    return parse_def_file(os.path.join(base_dir, 'riscv-ext.def'), base_dir)
+
+def get_all_extensions():
+    """Load both the implied-extension table and the full extension set.
+
+    riscv-ext.def is looked up next to this script, or in the current
+    working directory when __file__ is not defined (e.g., interactive
+    mode).
+
+    Returns:
+        The result of parse_def_file called with collect_all=True.
+    """
+    try:
+        script_path = os.path.abspath(__file__)
+    except NameError:
+        base_dir = os.getcwd()
+    else:
+        base_dir = os.path.dirname(script_path)
+
+    return parse_def_file(os.path.join(base_dir, 'riscv-ext.def'), base_dir,
+                          collect_all=True)
+
#
# IMPLIED_EXT(ext) -> implied extension list.
+# This is loaded dynamically from .def files
#
-IMPLIED_EXT = {
- "d" : ["f", "zicsr"],
-
- "a" : ["zaamo", "zalrsc"],
- "zabha" : ["zaamo"],
- "zacas" : ["zaamo"],
-
- "f" : ["zicsr"],
- "b" : ["zba", "zbb", "zbs"],
- "zdinx" : ["zfinx", "zicsr"],
- "zfinx" : ["zicsr"],
- "zhinx" : ["zhinxmin", "zfinx", "zicsr"],
- "zhinxmin" : ["zfinx", "zicsr"],
-
- "zk" : ["zkn", "zkr", "zkt"],
- "zkn" : ["zbkb", "zbkc", "zbkx", "zkne", "zknd", "zknh"],
- "zks" : ["zbkb", "zbkc", "zbkx", "zksed", "zksh"],
-
- "v" : ["zvl128b", "zve64d"],
- "zve32x" : ["zvl32b"],
- "zve64x" : ["zve32x", "zvl64b"],
- "zve32f" : ["f", "zve32x"],
- "zve64f" : ["f", "zve32f", "zve64x"],
- "zve64d" : ["d", "zve64f"],
-
- "zvl64b" : ["zvl32b"],
- "zvl128b" : ["zvl64b"],
- "zvl256b" : ["zvl128b"],
- "zvl512b" : ["zvl256b"],
- "zvl1024b" : ["zvl512b"],
- "zvl2048b" : ["zvl1024b"],
- "zvl4096b" : ["zvl2048b"],
- "zvl8192b" : ["zvl4096b"],
- "zvl16384b" : ["zvl8192b"],
- "zvl32768b" : ["zvl16384b"],
- "zvl65536b" : ["zvl32768b"],
-
- "zvkn" : ["zvkned", "zvknhb", "zvkb", "zvkt"],
- "zvknc" : ["zvkn", "zvbc"],
- "zvkng" : ["zvkn", "zvkg"],
- "zvks" : ["zvksed", "zvksh", "zvkb", "zvkt"],
- "zvksc" : ["zvks", "zvbc"],
- "zvksg" : ["zvks", "zvkg"],
- "zvbb" : ["zvkb"],
- "zvbc" : ["zve64x"],
- "zvkb" : ["zve32x"],
- "zvkg" : ["zve32x"],
- "zvkned" : ["zve32x"],
- "zvknha" : ["zve32x"],
- "zvknhb" : ["zve64x"],
- "zvksed" : ["zve32x"],
- "zvksh" : ["zve32x"],
-}
+# Built once when this module is loaded.  Each value is a list of dicts
+# carrying 'ext' and 'type' ('simple' or 'conditional'); conditional entries
+# also carry 'condition'.  The mapping is empty when riscv-ext.def is not
+# found.
+IMPLIED_EXT = parse_def_files()
def arch_canonicalize(arch, isa_spec):
# TODO: Support extension version.
long_exts += extra_long_ext
#
- # Handle implied extensions.
+ # Handle implied extensions using new conditional logic.
#
- any_change = True
- while any_change:
- any_change = False
- for ext in std_exts + long_exts:
- if ext in IMPLIED_EXT:
- implied_exts = IMPLIED_EXT[ext]
- for implied_ext in implied_exts:
- if implied_ext == 'zicsr' and is_isa_spec_2p2:
- continue
+ # Extract xlen from architecture string
+ # TODO: We should support profile here.
+ if arch.startswith('rv32'):
+ xlen = 32
+ elif arch.startswith('rv64'):
+ xlen = 64
+ else:
+ raise Exception("Unsupported prefix `%s`" % arch)
- if implied_ext not in std_exts + long_exts:
- long_exts.append(implied_ext)
- any_change = True
+ # Get all current extensions
+ current_exts = std_exts + long_exts
+
+ # Resolve dependencies
+ implied_deps = resolve_dependencies(current_exts, xlen)
+
+ # Filter out zicsr for ISA spec 2.2
+ if is_isa_spec_2p2:
+ implied_deps.discard('zicsr')
+
+ # Add implied dependencies to long_exts
+ for dep in implied_deps:
+ if dep not in current_exts:
+ long_exts.append(dep)
# Single letter extension might appear in the long_exts list,
# because we just append extensions list to the arch string.
return new_arch
-if len(sys.argv) < 2:
- print ("Usage: %s <arch_str> [<arch_str>*]" % sys.argv)
- sys.exit(1)
+def dump_all_extensions():
+ """Print every extension returned by get_all_extensions with its dependencies.
+
+ Extensions are listed in sorted order; conditional dependencies are
+ marked with a trailing '*'.  Three summary counts follow the listing.
+ Only "No extensions found." is printed after the header when the set of
+ extensions is empty.
+ """
+ implied_ext, all_extensions = get_all_extensions()
-parser = argparse.ArgumentParser()
-parser.add_argument('-misa-spec', type=str,
- default='20191213',
- choices=SUPPORTED_ISA_SPEC)
-parser.add_argument('arch_strs', nargs=argparse.REMAINDER)
+ print("All supported RISC-V extensions:")
+ print("=" * 60)
-args = parser.parse_args()
+ if not all_extensions:
+ print("No extensions found.")
+ return
-for arch in args.arch_strs:
- print (arch_canonicalize(arch, args.misa_spec))
+ # Sort all extensions for consistent output
+ sorted_all = sorted(all_extensions)
+
+ # Print all extensions with their dependencies (if any)
+ for ext_name in sorted_all:
+ if ext_name in implied_ext:
+ deps = implied_ext[ext_name]
+ dep_strs = []
+ for dep in deps:
+ if dep['type'] == 'simple':
+ dep_strs.append(dep['ext'])
+ else:
+ dep_strs.append(f"{dep['ext']}*") # Mark conditional deps with *
+ print(f"{ext_name:15} -> {', '.join(dep_strs)}")
+ else:
+ print(f"{ext_name:15} -> (no dependencies)")
+
+ # implied_ext only holds extensions that have at least one dependency.
+ print(f"\nTotal extensions: {len(all_extensions)}")
+ print(f"Extensions with dependencies: {len(implied_ext)}")
+ print(f"Extensions without dependencies: {len(all_extensions) - len(implied_ext)}")
+
+def run_unit_tests():
+    """Run the built-in self tests and report the outcome on stdout.
+
+    The tests are plain functions driven by the loop at the end of this
+    function, so no test framework has to be installed.
+
+    Returns:
+        0 when every test passed, 1 otherwise (suitable for sys.exit).
+    """
+
+    def test_basic_arch_parsing():
+        """Test basic architecture string parsing."""
+        result = arch_canonicalize("rv64i", "20191213")
+        assert result == "rv64i"
+
+    def test_simple_extensions():
+        """Test simple extension handling."""
+        result = arch_canonicalize("rv64im", "20191213")
+        assert "zmmul" in result
+
+    def test_implied_extensions():
+        """Test implied extension resolution."""
+        result = arch_canonicalize("rv64imaf", "20191213")
+        assert "zicsr" in result
+
+    def test_conditional_dependencies():
+        """Test conditional dependency evaluation."""
+        # RV32 with F and C must pull in zca and zcf.  The first element of
+        # the split is the whole "rv32..." prefix, never a bare "c" or "f",
+        # so the checks must not be guarded by such membership tests.
+        result = arch_canonicalize("rv32ifc", "20191213")
+        parts = result.split("_")
+        assert "zca" in parts
+        assert "zcf" in parts
+
+    def test_parse_dep_exts():
+        """Test dependency parsing function."""
+        # Test simple dependency
+        deps = parse_dep_exts('{"ext1", "ext2"}')
+        assert len(deps) == 2
+        assert deps[0]['ext'] == 'ext1'
+        assert deps[0]['type'] == 'simple'
+
+    def test_evaluate_conditional_dependency():
+        """Test conditional dependency evaluation."""
+        # Test zcf condition for RV32 with F
+        dep = {'ext': 'zcf', 'type': 'conditional', 'condition': 'test'}
+        result = evaluate_conditional_dependency('zce', dep, 32, {'f'})
+        assert result is True
+
+        # Test zcf condition for RV64 with F (should be False)
+        result = evaluate_conditional_dependency('zce', dep, 64, {'f'})
+        assert result is False
+
+    def test_parse_define_riscv_ext():
+        """Test DEFINE_RISCV_EXT parsing."""
+        content = '''
+        DEFINE_RISCV_EXT(
+          /* NAME */ test,
+          /* UPPERCASE_NAME */ TEST,
+          /* FULL_NAME */ "Test extension",
+          /* DESC */ "",
+          /* URL */ ,
+          /* DEP_EXTS */ ({"dep1", "dep2"}),
+          /* SUPPORTED_VERSIONS */ ({{1, 0}}),
+          /* FLAG_GROUP */ test,
+          /* BITMASK_GROUP_ID */ 0,
+          /* BITMASK_BIT_POSITION*/ 0,
+          /* EXTRA_EXTENSION_FLAGS */ 0)
+        '''
+
+        extensions = parse_define_riscv_ext(content)
+        assert len(extensions) == 1
+        assert extensions[0]['name'] == 'test'
+        assert len(extensions[0]['dep_exts']) == 2
+
+    # Collect test functions
+    test_functions = [
+        test_basic_arch_parsing,
+        test_simple_extensions,
+        test_implied_extensions,
+        test_conditional_dependencies,
+        test_parse_dep_exts,
+        test_evaluate_conditional_dependency,
+        test_parse_define_riscv_ext
+    ]
+
+    print("Running unit tests...")
+
+    passed = 0
+    failed = 0
+
+    for test_func in test_functions:
+        try:
+            print(f" Running {test_func.__name__}...", end=" ")
+            test_func()
+            print("PASSED")
+            passed += 1
+        except Exception as e:
+            # One failing test must not stop the remaining ones.
+            print(f"FAILED: {e}")
+            failed += 1
+
+    print(f"\nTest Summary: {passed} passed, {failed} failed")
+
+    if failed == 0:
+        print("\nAll tests passed!")
+        return 0
+    else:
+        print(f"\n{failed} test(s) failed!")
+        return 1
+
+if __name__ == "__main__":
+    # Command-line entry point: dump the extension table, run the self
+    # tests, or canonicalize each arch string given on the command line.
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '-misa-spec',
+        type=str,
+        default='20191213',
+        choices=SUPPORTED_ISA_SPEC)
+    parser.add_argument(
+        '--dump-all',
+        action='store_true',
+        help='Dump all extensions and their implied extensions')
+    parser.add_argument(
+        '--selftest',
+        action='store_true',
+        help='Run unit tests using pytest')
+    parser.add_argument(
+        'arch_strs',
+        nargs='*',
+        help='Architecture strings to canonicalize')
+
+    args = parser.parse_args()
+
+    # --dump-all takes precedence over --selftest, which takes precedence
+    # over positional arch strings.
+    if args.dump_all:
+        dump_all_extensions()
+    elif args.selftest:
+        sys.exit(run_unit_tests())
+    elif not args.arch_strs:
+        # Nothing to do: show the usage and signal failure.
+        parser.print_help()
+        sys.exit(1)
+    else:
+        for arch in args.arch_strs:
+            print (arch_canonicalize(arch, args.misa_spec))