#!/usr/bin/env python3 # libmisc/wrap-cc - Wrapper around GCC to enhance the preprocessor # # Copyright (C) 2025 Luke T. Shumaker # SPDX-License-Identifier: AGPL-3.0-or-later import os import re import subprocess import sys import typing def scan_tuple( text: str, beg: int, on_part: typing.Callable[[str], None] | None = None ) -> int: assert text[beg] == "(" pos = beg + 1 arg_start = pos parens = 1 instring = False while parens: c = text[pos] if instring: match c: case "\\": pos += 1 case '"': instring = False else: match c: case "(": parens += 1 case ")": parens -= 1 if on_part and parens == 0 and text[beg + 1 : pos].strip(): on_part(text[arg_start:pos]) case ",": if on_part and parens == 1: on_part(text[arg_start:pos]) arg_start = pos + 1 case '"': instring = True pos += 1 assert text[pos - 1] == ")" return pos - 1 def unquote(cstr: str) -> str: assert len(cstr) >= 2 and cstr[0] == '"' and cstr[-1] == '"' cstr = cstr[1:-1] out = "" while cstr: if cstr[0] == "\\": match cstr[1]: case "n": out += "\n" cstr = cstr[2:] case "\\": out += "\\" cstr = cstr[2:] case '"': out += '"' cstr = cstr[2:] else: out += cstr[0] cstr = cstr[1:] return out def preprocess(all_args: list[str]) -> typing.NoReturn: # argparse ################################################################# _args = all_args def shift(n: int) -> list[str]: nonlocal _args ret = _args[:n] _args = _args[n:] return ret arg0 = shift(1)[0] common_flags: list[str] = [] output_flags: list[str] = [] positional: list[str] = [] while _args: if len(_args[0]) > 2 and _args[0][0] == "-" and _args[0][1] in "IDU": _args = [_args[0][:2], _args[0][2:], *_args[1:]] match _args[0]: # Mode case "-E" | "-quiet": common_flags += shift(1) case "-lang-asm": os.execvp(all_args[0], all_args) # Search path case "-I" | "-imultilib" | "-isystem": common_flags += shift(2) # Define/Undefine case "-D" | "-U": common_flags += shift(2) # Optimization case "-O0" | "-O1" | "-O2" | "-O3" | "-Os" | "-Ofast" | "-Og" | "-Oz": common_flags += shift(1) case "-g": common_flags += shift(1) # Output files case "-MD" | "-MF" | "-MT" | "-dumpbase" | "-dumpbase-ext": output_flags += shift(2) case "-o": output_flags += shift(2) # Other case _: if _args[0].startswith("-"): if _args[0].startswith("-std="): common_flags += shift(1) elif _args[0].startswith("-m"): common_flags += shift(1) elif _args[0].startswith("-f"): common_flags += shift(1) elif _args[0].startswith("-W"): common_flags += shift(1) else: raise ValueError(f"unknown flag: {_args!r}") else: positional += shift(1) if len(positional) != 1: raise ValueError("expected 1 input file") infile = positional[0] # enhance ################################################################## common_flags += ["-D", "__LIBMISC_ENHANCED_CPP__"] text = subprocess.run( [arg0, *common_flags, infile], stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=sys.stderr, check=True, text=True, ).stdout macros: dict[str, str] = {} marker = "__xx__LM_DEFAPPEND__xx__" pos = 0 while (marker_beg := text.find(marker, pos)) >= 0: args: list[str] = [] def add_arg(arg: str) -> None: nonlocal args args.append(arg) beg_paren = marker_beg + len(marker) end_paren = scan_tuple(text, beg_paren, add_arg) before = text[:marker_beg] # old = text[marker_beg : end_paren + 1] after = text[end_paren + 1 :] assert len(args) == 2 k = unquote(args[0].strip()) v = unquote(args[1].strip()) if k not in macros: macros[k] = v else: macros[k] += " " + v text = before + after pos = len(before) common_flags += ["-D", marker + "=LM_EAT"] for k, v in macros.items(): common_flags += ["-D", k + "=" + v] # Run, for-real ############################################################ os.execvp(arg0, [arg0, *common_flags, *output_flags, infile]) def cpp_squash_linemarkers(text: str) -> str: out = "" buf_marker = "" buf_body = "" for line in text.splitlines(keepends=True): if line.startswith("# "): if buf_marker != line or buf_body.strip(): out += buf_marker if buf_body.strip(): out += buf_body buf_marker = line buf_body = "" else: buf_body += line if buf_body.strip(): out += buf_marker out += buf_body return out class Preprocessor: _cpp: list[str] _builtin_defines: list[str] | None = None def __init__(self, cpp: list[str]) -> None: self._cpp = cpp self._builtin_defines = None @property def builtin_defines(self) -> list[str]: if self._builtin_defines is None: self._builtin_defines = subprocess.run( [*self._cpp, "-quiet", "-undef", "-nostdinc", "-dM"], stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=sys.stderr, check=True, text=True, ).stdout.splitlines(keepends=True) return self._builtin_defines def process_file(self, infile: str, flags: list[str]) -> str: # First/main preprocessor pass. text = subprocess.run( [*self._cpp, "-dD", *flags, infile], stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=sys.stderr, check=True, text=True, ).stdout # Extra (subclass) processing. text = self.extra(text) # Split the combined "-dD" output into "-dM" output and normal # output. macro_text = "" normal_text = "" for line in text.splitlines(keepends=True): if line.startswith("#define ") or line.startswith("#undef "): macro_text += line normal_text += "\n" else: normal_text += line return normal_text def process_fragment(self, before: str, text: str) -> str: before_lines = before.splitlines(keepends=True) builtin_defines = self.builtin_defines prefix = "".join( [ line for line in before_lines if line.startswith("#") and line not in builtin_defines ] ) prefix = cpp_squash_linemarkers(prefix) text = subprocess.run( [ *self._cpp, "-quiet", "-undef", "-nostdinc", "-dD", ], input=prefix + text, stdout=subprocess.PIPE, stderr=sys.stderr, check=True, text=True, ).stdout text = cpp_squash_linemarkers(text) text = text[text.index(prefix) + len(prefix) :] if text.startswith("# "): text = "\n" + text text = self.extra(text) return text def extra(self, text: str) -> str: return text ################################################################################ class EnhancedPreprocessor(Preprocessor): def process_file(self, infile: str, flags: list[str]) -> str: return super().process_file(infile, ["-D__LIBMISC_ENHANCED_CPP__", *flags]) special_macros = [ "LM_EVAL", "LM_FOREACH_PARAM", "LM_FOREACH_TUPLE", ] re_special = re.compile( r"__xx_(?P" + "|".join([re.escape(m) for m in special_macros]) + r")_xx__\(" ) def extra(self, text: str) -> str: pos = 0 while intro := self.re_special.search(text, pos): nl = text.rfind("\n", 0, intro.start()) if text[nl + 1] == "#": pos = intro.end() continue macro = intro.group("macro") args: list[str] = [] def add_arg(arg: str) -> None: args.append(arg) beg_paren = intro.end() - 1 end_paren = cpp_scan_tuple(text, beg_paren, add_arg) before = text[: intro.start()] # old = text[intro.start() : end_paren + 1] after = text[end_paren + 1 :] new = self._eval_macro(before, macro, args) text = before + new + after pos = len(before) + len(new) return text def _eval_macro(self, before: str, macro: str, args: list[str]) -> str: match macro: case "LM_EVAL": # LM_EVAL(...) ret = ",".join(args) while True: ret2 = self.process_fragment(before, ret) if ret2 == ret: break ret = ret2 return ret case "LM_FOREACH_PARAM": # LM_FOREACH_PARAM(func, (fixedparams), params...) assert len(args) >= 2 func = args[0].strip() fixedparams = args[1].strip()[1:-1].strip() if fixedparams: fixedparams += ", " ret = "" for param in args[2:]: ret += f"{func}({fixedparams}{param})" return ret case "LM_FOREACH_TUPLE": # LM_FOREACH_TUPLE(tuples, func, fixedparams...) tuples_str = args[0].lstrip() func = args[1].strip() fixedparams = "".join([a.strip() + ", " for a in args[2:]]) ret = "" while tuples_str: end_paren = cpp_scan_tuple(tuples_str, 0) tup = tuples_str[1:end_paren] ret += f"{func}({fixedparams}{tup})" tuples_str = tuples_str[end_paren + 1 :].lstrip() return ret case _: raise ValueError(f"unknown macro: {macro}") ################################################################################ def main(all_args: list[str]) -> typing.NoReturn: if len(all_args) >= 2 and all_args[0].endswith("cc1") and all_args[1] == "-E": preprocess(all_args) else: os.execvp(all_args[0], all_args) if __name__ == "__main__": main(sys.argv[1:])