summary refs log tree commit diff
path: root/lib9p/idl/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/idl/__init__.py')
-rw-r--r-- lib9p/idl/__init__.py 305
1 files changed, 170 insertions, 135 deletions
diff --git a/lib9p/idl/__init__.py b/lib9p/idl/__init__.py
index a01c38f..e7b3670 100644
--- a/lib9p/idl/__init__.py
+++ b/lib9p/idl/__init__.py
@@ -6,8 +6,9 @@
import enum
import os.path
import re
-from typing import Callable, Literal, TypeVar, cast
+import typing
+# pylint: disable=unused-variable
__all__ = [
# entrypoint
"Parser",
@@ -15,7 +16,7 @@ __all__ = [
"Type",
"Primitive",
"Number",
- *["Bitfield", "BitfieldVal"],
+ *["Bitfield", "Bit", "BitCat", "BitAlias"],
*["Struct", "StructMember", "Expr", "ExprOp", "ExprSym", "ExprLit"],
"Message",
]
@@ -36,7 +37,7 @@ class Primitive(enum.Enum):
return set()
@property
- def name(self) -> str:
+ def typname(self) -> str:
return str(self.value)
@property
@@ -51,7 +52,7 @@ class Primitive(enum.Enum):
class Number:
- name: str
+ typname: str
in_versions: set[str]
prim: Primitive
@@ -73,27 +74,49 @@ class Number:
return self.static_size
-class BitfieldVal:
- name: str
+class BitCat(enum.Enum):
+ UNUSED = 1
+ USED = 2
+ RESERVED = 3
+ SUBFIELD = 4
+
+
+class Bit:
+ bitname: str
in_versions: set[str]
+ num: int
+ cat: BitCat
- val: str
+ def __init__(self, num: int) -> None:
+ self.bitname = ""
+ self.in_versions = set()
+ self.num = num
+ self.cat = BitCat.UNUSED
- def __init__(self) -> None:
+
+class BitAlias:
+ bitname: str
+ in_versions: set[str]
+ val: str # FIXME: Don't have bitfield aliases be raw C expressions
+
+ def __init__(self, name: str, val: str) -> None:
+ self.bitname = name
self.in_versions = set()
+ self.val = val
class Bitfield:
- name: str
+ typname: str
in_versions: set[str]
-
prim: Primitive
+ bits: list[Bit]
+ names: dict[str, Bit | BitAlias]
- bits: list[str] # bitnames
- names: dict[str, BitfieldVal] # bits *and* aliases
-
- def __init__(self) -> None:
+ def __init__(self, name: str, prim: Primitive) -> None:
+ self.typname = name
self.in_versions = set()
+ self.prim = prim
+ self.bits = [Bit(i) for i in range(prim.static_size * 8)]
self.names = {}
@property
@@ -106,21 +129,6 @@ class Bitfield:
def max_size(self, version: str) -> int:
return self.static_size
- def bit_is_valid(self, bit: str | int, ver: str | None = None) -> bool:
- """Return whether the given bit is valid in the given protocol
- version.
-
- """
- bitname = self.bits[bit] if isinstance(bit, int) else bit
- assert bitname in self.bits
- if not bitname:
- return False
- if bitname.startswith("_"):
- return False
- if ver and (ver not in self.names[bitname].in_versions):
- return False
- return True
-
class ExprLit:
val: int
@@ -130,16 +138,16 @@ class ExprLit:
class ExprSym:
- name: str
+ symname: str
def __init__(self, name: str) -> None:
- self.name = name
+ self.symname = name
class ExprOp:
- op: Literal["-", "+"]
+ op: typing.Literal["-", "+"]
- def __init__(self, op: Literal["-", "+"]) -> None:
+ def __init__(self, op: typing.Literal["-", "+"]) -> None:
self.op = op
@@ -156,7 +164,7 @@ class Expr:
class StructMember:
# from left-to-right when parsing
cnt: "StructMember | None" = None
- name: str
+ membname: str
typ: "Type"
max: Expr
val: Expr
@@ -168,10 +176,10 @@ class StructMember:
assert self.cnt
if not isinstance(self.cnt.typ, Primitive):
raise ValueError(
- f"list count must be an integer type: {repr(self.cnt.name)}"
+ f"list count must be an integer type: {self.cnt.membname!r}"
)
if self.cnt.val: # TODO: allow this?
- raise ValueError(f"list count may not have ,val=: {repr(self.cnt.name)}")
+ raise ValueError(f"list count may not have ,val=: {self.cnt.membname!r}")
return 0
@property
@@ -179,26 +187,26 @@ class StructMember:
assert self.cnt
if not isinstance(self.cnt.typ, Primitive):
raise ValueError(
- f"list count must be an integer type: {repr(self.cnt.name)}"
+ f"list count must be an integer type: {self.cnt.membname!r}"
)
if self.cnt.val: # TODO: allow this?
- raise ValueError(f"list count may not have ,val=: {repr(self.cnt.name)}")
+ raise ValueError(f"list count may not have ,val=: {self.cnt.membname!r}")
if self.cnt.max:
# TODO: be more flexible?
if len(self.cnt.max.tokens) != 1:
raise ValueError(
- f"list count ,max= may only have 1 token: {repr(self.cnt.name)}"
+ f"list count ,max= may only have 1 token: {self.cnt.membname!r}"
)
match tok := self.cnt.max.tokens[0]:
case ExprLit():
return tok.val
- case ExprSym(name="s32_max"):
+ case ExprSym(symname="s32_max"):
return (1 << 31) - 1
- case ExprSym(name="s64_max"):
+ case ExprSym(symname="s64_max"):
return (1 << 63) - 1
case _:
raise ValueError(
- f'list count ,max= only allows literal, "s32_max", and "s64_max" tokens: {repr(self.cnt.name)}'
+ f'list count ,max= only allows literal, "s32_max", and "s64_max" tokens: {self.cnt.membname!r}'
)
return (1 << (self.cnt.typ.value * 8)) - 1
@@ -218,7 +226,7 @@ class StructMember:
class Struct:
- name: str
+ typname: str
in_versions: set[str]
members: list[StructMember]
@@ -257,7 +265,7 @@ class Message(Struct):
@property
def msgid(self) -> int:
assert len(self.members) >= 3
- assert self.members[1].name == "typ"
+ assert self.members[1].membname == "typ"
assert self.members[1].static_size == 1
assert self.members[1].val
assert len(self.members[1].val.tokens) == 1
@@ -266,12 +274,15 @@ class Message(Struct):
type Type = Primitive | Number | Bitfield | Struct | Message
-T = TypeVar("T", Number, Bitfield, Struct, Message)
+type UserType = Number | Bitfield | Struct | Message
+T = typing.TypeVar("T", Number, Bitfield, Struct, Message)
# Parse ########################################################################
re_priname = "(?:1|2|4|8)" # primitive names
re_symname = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" # "symbol" names; most *.9p-defined names
+re_symname_u = "(?:[A-Z_][A-Z_0-9]*)" # upper-case "symbol" names; bit names
+re_symname_l = "(?:[a-z_][a-z_0-9]*)" # lower-case "symbol" names; bit names
re_impname = r"(?:\*|" + re_symname + ")" # names we can import
re_msgname = r"(?:[TR][a-zA-Z_0-9]*)" # names a message can be
@@ -281,8 +292,18 @@ re_expr = f"(?:(?:-|\\+|[0-9]+|&?{re_symname})+)"
re_numspec = f"(?P<name>{re_symname})\\s*=\\s*(?P<val>\\S+)"
-re_bitspec_bit = f"(?P<bit>[0-9]+)\\s*=\\s*(?P<name>{re_symname})"
-re_bitspec_alias = f"(?P<name>{re_symname})\\s*=\\s*(?P<val>\\S+)"
+re_bitspec_bit = (
+ "(?P<bitnum>[0-9]+)\\s*=\\s*(?:"
+ + "|".join(
+ [
+ f"(?P<name_used>{re_symname_u})",
+ f"reserved\\((?P<name_reserved>{re_symname_u})\\)",
+ f"subfield\\((?P<name_subfield>{re_symname_l})\\)",
+ ]
+ )
+ + ")"
+)
+re_bitspec_alias = f"(?P<name>{re_symname_u})\\s*=\\s*(?P<val>\\S+)"
re_memberspec = f"(?:(?P<cnt>{re_symname})\\*\\()?(?P<name>{re_symname})\\[(?P<typ>{re_memtype})(?:,max=(?P<max>{re_expr})|,val=(?P<val>{re_expr}))*\\]\\)?"
@@ -294,55 +315,67 @@ def parse_numspec(ver: str, n: Number, spec: str) -> None:
name = m.group("name")
val = m.group("val")
if name in n.vals:
- raise ValueError(f"{n.name}: name {repr(name)} already assigned")
+ raise ValueError(f"{n.typname}: name {name!r} already assigned")
n.vals[name] = val
else:
- raise SyntaxError(f"invalid num spec {repr(spec)}")
+ raise SyntaxError(f"invalid num spec {spec!r}")
def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None:
spec = spec.strip()
- bit: int | None
- val: BitfieldVal
if m := re.fullmatch(re_bitspec_bit, spec):
- bit = int(m.group("bit"))
- name = m.group("name")
-
- val = BitfieldVal()
- val.name = name
- val.val = f"1<<{bit}"
- val.in_versions.add(ver)
-
- if bit < 0 or bit >= len(bf.bits):
- raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds")
- if bf.bits[bit]:
- raise ValueError(f"{bf.name}: bit {bit} already assigned")
- bf.bits[bit] = val.name
+ bitnum = int(m.group("bitnum"))
+ if bitnum < 0 or bitnum >= len(bf.bits):
+ raise ValueError(f"{bf.typname}: bit num {bitnum} out-of-bounds")
+ bit = bf.bits[bitnum]
+ if bit.cat != BitCat.UNUSED:
+ raise ValueError(f"{bf.typname}: bit num {bitnum} already assigned")
+ if name := m.group("name_used"):
+ bit.bitname = name
+ bit.cat = BitCat.USED
+ bit.in_versions.add(ver)
+ elif name := m.group("name_reserved"):
+ bit.bitname = name
+ bit.cat = BitCat.RESERVED
+ bit.in_versions.add(ver)
+ elif name := m.group("name_subfield"):
+ bit.bitname = name
+ bit.cat = BitCat.SUBFIELD
+ bit.in_versions.add(ver)
+ if bit.bitname:
+ if bit.bitname in bf.names:
+ other = bf.names[bit.bitname]
+ if (
+ isinstance(other, Bit)
+ and other.cat == bit.cat
+ and bit.cat == BitCat.SUBFIELD
+ ):
+ return
+ raise ValueError(
+ f"{bf.typname}: bit name {bit.bitname!r} already assigned"
+ )
+ bf.names[bit.bitname] = bit
elif m := re.fullmatch(re_bitspec_alias, spec):
- name = m.group("name")
- valstr = m.group("val")
-
- val = BitfieldVal()
- val.name = name
- val.val = valstr
- val.in_versions.add(ver)
+ alias = BitAlias(m.group("name"), m.group("val"))
+ alias.in_versions.add(ver)
+ if alias.bitname in bf.names:
+ raise ValueError(
+ f"{bf.typname}: bit name {alias.bitname!r} already assigned"
+ )
+ bf.names[alias.bitname] = alias
else:
- raise SyntaxError(f"invalid bitfield spec {repr(spec)}")
-
- if val.name in bf.names:
- raise ValueError(f"{bf.name}: name {val.name} already assigned")
- bf.names[val.name] = val
+ raise SyntaxError(f"invalid bitfield spec {spec!r}")
def parse_expr(expr: str) -> Expr:
assert re.fullmatch(re_expr, expr)
ret = Expr()
for tok in re.split("([-+])", expr):
- if tok == "-" or tok == "+":
+ if tok in ("-", "+"):
# I, for the life of me, do not understand why I need this
- # cast() to keep mypy happy.
- ret.tokens += [ExprOp(cast(Literal["-", "+"], tok))]
+ # typing.cast() to keep mypy happy.
+ ret.tokens += [ExprOp(typing.cast(typing.Literal["-", "+"], tok))]
elif re.fullmatch("[0-9]+", tok):
ret.tokens += [ExprLit(int(tok))]
else:
@@ -354,22 +387,22 @@ def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) ->
for spec in specs.split():
m = re.fullmatch(re_memberspec, spec)
if not m:
- raise SyntaxError(f"invalid member spec {repr(spec)}")
+ raise SyntaxError(f"invalid member spec {spec!r}")
member = StructMember()
member.in_versions = {ver}
- member.name = m.group("name")
- if any(x.name == member.name for x in struct.members):
- raise ValueError(f"duplicate member name {repr(member.name)}")
+ member.membname = m.group("name")
+ if any(x.membname == member.membname for x in struct.members):
+ raise ValueError(f"duplicate member name {member.membname!r}")
if m.group("typ") not in env:
- raise NameError(f"Unknown type {repr(m.group('typ'))}")
+ raise NameError(f"Unknown type {m.group('typ')!r}")
member.typ = env[m.group("typ")]
if cnt := m.group("cnt"):
- if len(struct.members) == 0 or struct.members[-1].name != cnt:
- raise ValueError(f"list count must be previous item: {repr(cnt)}")
+ if len(struct.members) == 0 or struct.members[-1].membname != cnt:
+ raise ValueError(f"list count must be previous item: {cnt!r}")
cnt_mem = struct.members[-1]
member.cnt = cnt_mem
_ = member.max_cnt # force validation
@@ -412,8 +445,8 @@ re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg
def parse_file(
- filename: str, get_include: Callable[[str], tuple[str, list[Type]]]
-) -> tuple[str, list[Type]]:
+ filename: str, get_include: typing.Callable[[str], tuple[str, list[UserType]]]
+) -> tuple[str, list[UserType]]:
version: str | None = None
env: dict[str, Type] = {
"1": Primitive.u8,
@@ -425,13 +458,13 @@ def parse_file(
def get_type(name: str, tc: type[T]) -> T:
nonlocal env
if name not in env:
- raise NameError(f"Unknown type {repr(name)}")
+ raise NameError(f"Unknown type {name!r}")
ret = env[name]
if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__):
- raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}")
+ raise NameError(f"Type {ret.typname!r} is not a {tc.__name__}")
return ret
- with open(filename, "r") as fh:
+ with open(filename, "r", encoding="utf-8") as fh:
prev: Type | None = None
for lineno, line in enumerate(fh):
try:
@@ -452,7 +485,7 @@ def parse_file(
symname = symname.strip()
found = False
for typ in other_typs:
- if typ.name == symname or symname == "*":
+ if symname in (typ.typname, "*"):
found = True
match typ:
case Primitive():
@@ -461,6 +494,9 @@ def parse_file(
typ.in_versions.add(version)
case Bitfield():
typ.in_versions.add(version)
+ for bit in typ.bits:
+ if other_version in bit.in_versions:
+ bit.in_versions.add(version)
for val in typ.names.values():
if other_version in val.in_versions:
val.in_versions.add(version)
@@ -469,42 +505,38 @@ def parse_file(
for member in typ.members:
if other_version in member.in_versions:
member.in_versions.add(version)
- if typ.name in env and env[typ.name] != typ:
+ if typ.typname in env and env[typ.typname] != typ:
raise ValueError(
- f"duplicate type name {repr(typ.name)}"
+ f"duplicate type name {typ.typname!r}"
)
- env[typ.name] = typ
+ env[typ.typname] = typ
if symname != "*" and not found:
raise ValueError(
- f"import: {m.group('file')}: no symbol {repr(symname)}"
+ f"import: {m.group('file')}: no symbol {symname!r}"
)
elif m := re.fullmatch(re_line_num, line):
num = Number()
- num.name = m.group("name")
+ num.typname = m.group("name")
num.in_versions.add(version)
prim = env[m.group("prim")]
assert isinstance(prim, Primitive)
num.prim = prim
- if num.name in env:
- raise ValueError(f"duplicate type name {repr(num.name)}")
- env[num.name] = num
+ if num.typname in env:
+ raise ValueError(f"duplicate type name {num.typname!r}")
+ env[num.typname] = num
prev = num
elif m := re.fullmatch(re_line_bitfield, line):
- bf = Bitfield()
- bf.name = m.group("name")
- bf.in_versions.add(version)
-
prim = env[m.group("prim")]
assert isinstance(prim, Primitive)
- bf.prim = prim
- bf.bits = (prim.static_size * 8) * [""]
+ bf = Bitfield(m.group("name"), prim)
+ bf.in_versions.add(version)
- if bf.name in env:
- raise ValueError(f"duplicate type name {repr(bf.name)}")
- env[bf.name] = bf
+ if bf.typname in env:
+ raise ValueError(f"duplicate type name {bf.typname!r}")
+ env[bf.typname] = bf
prev = bf
elif m := re.fullmatch(re_line_bitfield_, line):
bf = get_type(m.group("name"), Bitfield)
@@ -515,16 +547,16 @@ def parse_file(
match m.group("op"):
case "=":
struct = Struct()
- struct.name = m.group("name")
+ struct.typname = m.group("name")
struct.in_versions.add(version)
struct.members = []
parse_members(version, env, struct, m.group("members"))
- if struct.name in env:
+ if struct.typname in env:
raise ValueError(
- f"duplicate type name {repr(struct.name)}"
+ f"duplicate type name {struct.typname!r}"
)
- env[struct.name] = struct
+ env[struct.typname] = struct
prev = struct
case "+=":
struct = get_type(m.group("name"), Struct)
@@ -535,16 +567,14 @@ def parse_file(
match m.group("op"):
case "=":
msg = Message()
- msg.name = m.group("name")
+ msg.typname = m.group("name")
msg.in_versions.add(version)
msg.members = []
parse_members(version, env, msg, m.group("members"))
- if msg.name in env:
- raise ValueError(
- f"duplicate type name {repr(msg.name)}"
- )
- env[msg.name] = msg
+ if msg.typname in env:
+ raise ValueError(f"duplicate type name {msg.typname!r}")
+ env[msg.typname] = msg
prev = msg
case "+=":
msg = get_type(m.group("name"), Message)
@@ -574,22 +604,27 @@ def parse_file(
if not version:
raise SyntaxError("must have exactly 1 version line")
- typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)]
+ typs: list[UserType] = [x for x in env.values() if not isinstance(x, Primitive)]
for typ in [typ for typ in typs if isinstance(typ, Struct)]:
- valid_syms = ["end", "s32_max", "s64_max", *["&" + m.name for m in typ.members]]
+ valid_syms = [
+ "end",
+ "s32_max",
+ "s64_max",
+ *["&" + m.membname for m in typ.members],
+ ]
for member in typ.members:
if (
not isinstance(member.typ, Primitive)
and member.typ.in_versions < member.in_versions
):
raise ValueError(
- f"{typ.name}.{member.name}: type {member.typ.name} does not exist in {member.in_versions.difference(member.typ.in_versions)}"
+ f"{typ.typname}.{member.membname}: type {member.typ.typname} does not exist in {member.in_versions.difference(member.typ.in_versions)}"
)
for tok in [*member.max.tokens, *member.val.tokens]:
- if isinstance(tok, ExprSym) and tok.name not in valid_syms:
+ if isinstance(tok, ExprSym) and tok.symname not in valid_syms:
raise ValueError(
- f"{typ.name}.{member.name}: invalid sym: {tok.name}"
+ f"{typ.typname}.{member.membname}: invalid sym: {tok.symname}"
)
return version, typs
@@ -599,35 +634,35 @@ def parse_file(
class Parser:
- cache: dict[str, tuple[str, list[Type]]] = {}
+ cache: dict[str, tuple[str, list[UserType]]] = {}
- def parse_file(self, filename: str) -> tuple[str, list[Type]]:
+ def parse_file(self, filename: str) -> tuple[str, list[UserType]]:
filename = os.path.normpath(filename)
if filename not in self.cache:
- def get_include(other_filename: str) -> tuple[str, list[Type]]:
+ def get_include(other_filename: str) -> tuple[str, list[UserType]]:
return self.parse_file(os.path.join(filename, "..", other_filename))
self.cache[filename] = parse_file(filename, get_include)
return self.cache[filename]
- def all(self) -> tuple[set[str], list[Type]]:
+ def all(self) -> tuple[set[str], list[UserType]]:
ret_versions: set[str] = set()
- ret_typs: dict[str, Type] = {}
+ ret_typs: dict[str, UserType] = {}
for version, typs in self.cache.values():
if version in ret_versions:
- raise ValueError(f"duplicate protocol version {repr(version)}")
+ raise ValueError(f"duplicate protocol version {version!r}")
ret_versions.add(version)
for typ in typs:
- if typ.name in ret_typs:
- if typ != ret_typs[typ.name]:
- raise ValueError(f"duplicate type name {repr(typ.name)}")
+ if typ.typname in ret_typs:
+ if typ != ret_typs[typ.typname]:
+ raise ValueError(f"duplicate type name {typ.typname!r}")
else:
- ret_typs[typ.name] = typ
+ ret_typs[typ.typname] = typ
msgids: set[int] = set()
for typ in ret_typs.values():
if isinstance(typ, Message):
if typ.msgid in msgids:
- raise ValueError(f"duplicate msgid {repr(typ.msgid)}")
+ raise ValueError(f"duplicate msgid {typ.msgid!r}")
msgids.add(typ.msgid)
return ret_versions, list(ret_typs.values())