import abc
import enum
import struct
import typing
from struct import Struct
from collections import deque
from .exceptions import ParseError
from . import (
read_raw_struct,
read,
read_until,
read_struct,
Parser,
get_parser,
wait,
read_int,
)
# Stack of BinarySchema instances whose __init__ is currently running.
# BinarySchema.__init__ pushes itself here and Switch.__call__ reads the top
# entry to look up the attribute that selects its case.
_parent_stack = deque()
# Stack of the partially filled ``name -> value`` mappings of schemas that are
# currently being parsed.  BinarySchema.get_value pushes its mapping here and
# Switch.get_value reads the top entry to look up the already-parsed field
# that selects its case.
_mapping_stack = deque()
class Unit(abc.ABC):
    """Abstract base class shared by every unit.

    A custom unit class must inherit from ``Unit`` and implement both
    :meth:`get_value` (bytes -> object) and :meth:`__call__`
    (object -> bytes).
    """

    @abc.abstractmethod
    def get_value(self) -> typing.Generator:
        """Return a generator that produces the wanted object from bytes."""

    @abc.abstractmethod
    def __call__(self, obj: typing.Any) -> bytes:
        """Convert the user-given ``obj`` to bytes."""

    def parse(self, data: bytes, *, strict=True):
        """Convenience helper to parse a fixed ``bytes`` buffer.

        Wraps :meth:`get_value` in a ``Parser`` and returns whatever
        ``Parser.parse(data, strict=strict)`` returns.
        """
        parser = Parser(self.get_value())
        return parser.parse(data, strict=strict)
class BinarySchema(metaclass=BinarySchemaMetaclass):
    """The main class for users to define their own binary structures.

    An instance is built from one positional argument per entry of the
    class-level ``_fields`` mapping (populated outside this class --
    presumably by ``BinarySchemaMetaclass``; confirm).  After construction
    every field value is available as an attribute of the same name,
    ``bins`` maps each field name to its serialized bytes and ``binary`` is
    the concatenation of those bytes in field order.
    """

    def __init__(self, *args):
        """Serialize ``args`` field by field.

        Passing ``...`` for a field stores ``field.get_default()`` as the
        attribute value; the field itself is still called with ``...`` to
        produce the bytes.

        Raises:
            ValueError: the number of arguments differs from the number of
                fields.
            TypeError: a field is neither a schema class nor a ``Unit``.
        """
        fields = self.__class__._fields
        if len(args) != len(fields):
            raise ValueError(f"need {len(fields)} args, got {len(args)}")
        # Units such as Switch look at the top of this stack to find the
        # schema instance that is currently being built.
        _parent_stack.append(self)
        try:
            self.bins = {}
            for arg, (name, field) in zip(args, fields.items()):
                if isinstance(field, BinarySchemaMetaclass):
                    binary = arg.binary
                elif isinstance(field, Unit):
                    binary = field(arg)
                else:
                    # Without this branch ``binary`` would be unbound (or,
                    # worse, silently reuse the previous field's bytes).
                    raise TypeError(
                        f"field {name!r} must be a BinarySchema class or a "
                        f"Unit, got {field!r}"
                    )
                if arg is ...:
                    arg = field.get_default()
                # Set the attribute immediately: later fields may refer to it.
                setattr(self, name, arg)
                self.bins[name] = binary
            self.binary = b"".join(self.bins.values())
        finally:
            # Always unwind, even when serializing a field raised, so a
            # failed construction does not leave a stale parent behind.
            _parent_stack.pop()

    def __str__(self):
        """Return ``ClassName(field=value, ...)`` using ``repr`` of each value."""
        parts = [
            f"{name}={getattr(self, name)!r}" for name in self.__class__._fields
        ]
        s = ", ".join(parts)
        return f"{self.__class__.__name__}({s})"

    def __repr__(self):
        return f"<{self}>"

    # NOTE: defining __eq__ without __hash__ makes instances unhashable.
    def __eq__(self, other: "BinarySchema") -> bool:
        """Return True when ``other`` is an instance of this class with equal fields."""
        if not isinstance(other, self.__class__):
            return False
        return not any(
            getattr(self, name) != getattr(other, name)
            for name in self.__class__._fields
        )

    @classmethod
    def get_value(cls) -> typing.Generator:
        """Get a `BinarySchema` object from bytes.

        Delegates to each field's ``get_value`` generator in field order and
        builds an instance from the collected values.

        Raises:
            ParseError: any field raised an ``Exception``; the error carries
                the mapping of fields parsed so far and is chained to the
                original exception.
        """
        mapping = {}
        # Units such as Switch read the top of this stack to see the fields
        # that have already been parsed.
        _mapping_stack.append(mapping)
        try:
            for name, field in cls._fields.items():
                mapping[name] = yield from field.get_value()
        except Exception as exc:
            raise ParseError(mapping) from exc
        finally:
            # Runs on success, on failure and when the generator is closed,
            # so the stack never keeps a mapping of a finished parse.
            _mapping_stack.pop()
        return cls(*mapping.values())

    @classmethod
    def get_parser(cls) -> Parser:
        """Return a new ``Parser`` driving :meth:`get_value`."""
        return Parser(cls.get_value())

    @classmethod
    def parse(cls, data: bytes, *, strict=True) -> "BinarySchema":
        """Parse ``data`` with a fresh parser, forwarding ``strict``."""
        return cls.get_parser().parse(data, strict=strict)
# Anything usable as a field: either a schema class (an instance of
# BinarySchemaMetaclass) or a Unit instance.
FieldType = typing.Union[BinarySchemaMetaclass, Unit]
class StructUnit(Unit):
    """Unit backed by a pre-compiled ``struct.Struct`` for a single value."""

    def __init__(self, format_: str):
        """Compile ``format_`` once so it can be reused for every call."""
        self._struct = Struct(format_)

    def __str__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self._struct.format})"

    def get_value(self):
        """Read the struct via ``read_raw_struct`` and return its first item."""
        unpacked = yield from read_raw_struct(self._struct)
        return unpacked[0]

    def __call__(self, obj) -> bytes:
        """Pack ``obj`` with the compiled struct."""
        packer = self._struct
        return packer.pack(obj)
class IntUnit(Unit):
    """Integer unit of an arbitrary byte length.

    Used for widths that ``struct`` has no format code for; serialization
    goes through ``int.to_bytes`` and parsing through ``read_int``.
    """

    def __init__(self, length: int, byteorder: str, signed: bool = False):
        """Store the byte ``length``, ``byteorder`` and signedness.

        ``byteorder`` and ``signed`` are passed unchanged to ``read_int`` and
        ``int.to_bytes``.
        """
        self.length = length
        self.byteorder = byteorder
        self.signed = signed

    def __str__(self):
        # Mirrors the other units so that composite units print a readable
        # description instead of the default object repr.
        return (
            f"{self.__class__.__name__}"
            f"({self.length}, {self.byteorder}, signed={self.signed})"
        )

    def get_value(self):
        """Read ``length`` bytes as an integer via ``read_int``."""
        return (
            yield from read_int(
                self.length, byteorder=self.byteorder, signed=self.signed
            )
        )

    def __call__(self, obj: int) -> bytes:
        """Return ``obj`` encoded as exactly ``length`` bytes."""
        return obj.to_bytes(self.length, self.byteorder, signed=self.signed)
# Ready-made numeric units.
#
# Struct-based units whose format has no byte-order prefix use ``struct``'s
# default mode (native byte order, size and alignment); the ``be`` variants
# use ``>`` (big-endian, standard size).
# NOTE(review): the un-suffixed 24-bit units below are always little-endian,
# while the un-suffixed struct units follow the host byte order -- on a
# big-endian host these two families disagree; confirm this is intended.

# 8-bit integers
int8 = StructUnit("b")
uint8 = StructUnit("B")
# 16-bit integers
int16 = StructUnit("h")
int16be = StructUnit(">h")
uint16 = StructUnit("H")
uint16be = StructUnit(">H")
# 24-bit integers: struct has no 3-byte code, so IntUnit is used instead
int24 = IntUnit(3, "little", signed=True)
int24be = IntUnit(3, "big", signed=True)
uint24 = IntUnit(3, "little", signed=False)
uint24be = IntUnit(3, "big", signed=False)
# 32-bit integers
int32 = StructUnit("i")
int32be = StructUnit(">i")
uint32 = StructUnit("I")
uint32be = StructUnit(">I")
# 64-bit integers
int64 = StructUnit("q")
int64be = StructUnit(">q")
uint64 = StructUnit("Q")
uint64be = StructUnit(">Q")
# Floating point (struct "f" and "d" codes)
float32 = StructUnit("f")
float32be = StructUnit(">f")
float64 = StructUnit("d")
float64be = StructUnit(">d")
class Bytes(Unit):
    """Raw bytes field.

    A non-negative ``length`` means a fixed-size field handled through a
    ``"<length>s"`` struct.  A negative ``length`` means no fixed size:
    parsing delegates to ``read()`` and serialization returns the object
    unchanged.
    """

    def __init__(self, length):
        self.length = length
        # The struct only exists for fixed-size fields.
        if length >= 0:
            self._struct = Struct(f"{length}s")

    def __str__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.length})"

    def get_value(self):
        """Return the bytes read for this field."""
        if self.length < 0:
            return (yield from read())
        unpacked = yield from read_raw_struct(self._struct)
        return unpacked[0]

    def __call__(self, obj) -> bytes:
        """Return the serialized form of ``obj``."""
        if self.length < 0:
            return obj
        return self._struct.pack(obj)
class String(Bytes):
    """Text field stored as encoded bytes.

    Behaves like ``Bytes`` for sizing (including a negative ``length``) and
    additionally encodes on serialization and decodes on parsing with
    ``encoding``.
    """

    def __init__(self, length, encoding="utf-8"):
        super().__init__(length)
        self.encoding = encoding

    def __str__(self):
        return f"{self.__class__.__name__}({self.length})"

    def get_value(self):
        """Read the raw bytes like ``Bytes`` does and decode them.

        Delegating to ``Bytes.get_value`` keeps negative lengths working:
        reading ``self._struct`` directly fails there because the struct is
        only created for non-negative lengths.
        """
        raw = yield from super().get_value()
        return raw.decode(self.encoding)

    def __call__(self, obj: str) -> bytes:
        """Encode ``obj`` and serialize the result like ``Bytes`` does."""
        return super().__call__(obj.encode(self.encoding))
class MustEqual(Unit):
    """Wrap ``unit`` and require its value to always equal ``value``."""

    def __init__(self, unit, value):
        self.unit = unit
        self.value = value

    def __str__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.unit}, {self.value})"

    def get_value(self):
        """Parse with the wrapped unit; raise ``ValueError`` on a mismatch."""
        parsed = yield from self.unit.get_value()
        if self.value != parsed:
            raise ValueError(f"expect {self.value}, got {parsed}")
        return parsed

    def get_default(self):
        """Return the fixed value; used when the caller passes ``...``."""
        return self.value

    def __call__(self, obj) -> bytes:
        """Serialize the fixed value.

        ``obj`` may be ``...`` (no check); any other object must equal the
        fixed value, otherwise ``ValueError`` is raised.
        """
        if obj is not ... and self.value != obj:
            raise ValueError(f"expect {self.value}, got {obj}")
        return self.unit(self.value)
class EndWith(Unit):
    """Bytes field delimited by a trailing terminator ``bytes_``."""

    def __init__(self, bytes_):
        self.bytes_ = bytes_

    def __str__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.bytes_})"

    def get_value(self):
        """Return what ``read_until`` yields for the terminator.

        ``return_tail=False`` is passed through to ``read_until``.
        """
        payload = yield from read_until(self.bytes_, return_tail=False)
        return payload

    def __call__(self, obj: bytes) -> bytes:
        """Return ``obj`` followed by the terminator."""
        terminator = self.bytes_
        return obj + terminator
class LengthPrefixedBytes(Unit):
    """Bytes preceded by their length, which is encoded with ``length_unit``."""

    def __init__(self, length_unit: StructUnit):
        self.length_unit = length_unit

    def __str__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.length_unit})"

    def get_value(self):
        """Read the length prefix, then that many bytes."""
        size = yield from self.length_unit.get_value()
        unpacked = yield from read_struct(f"{size}s")
        return unpacked[0]

    def __call__(self, obj: bytes) -> bytes:
        """Return the encoded length of ``obj`` followed by ``obj``."""
        size = len(obj)
        prefix = self.length_unit(size)
        payload = struct.pack(f"{size}s", obj)
        return prefix + payload
class LengthPrefixedString(Unit):
    """Encoded text preceded by its byte length.

    The length prefix is encoded with ``length_unit`` and counts *encoded
    bytes*, matching what :meth:`get_value` reads back.
    """

    def __init__(self, length_unit: StructUnit, encoding="utf-8"):
        self.length_unit = length_unit
        self.encoding = encoding

    def __str__(self):
        return f"{self.__class__.__name__}({self.length_unit}, {self.encoding})"

    def get_value(self):
        """Read the length prefix, then that many bytes, and decode them."""
        length = yield from self.length_unit.get_value()
        v, = yield from read_struct(f"{length}s")
        return v.decode(self.encoding)

    def __call__(self, obj: str) -> bytes:
        """Return the encoded byte length of ``obj`` followed by its bytes."""
        # Encode first: the prefix must be the number of bytes, not the
        # number of characters.  Using len(obj) truncated the payload and
        # wrote a wrong prefix for any text with multi-byte characters.
        encoded = obj.encode(self.encoding)
        return self.length_unit(len(encoded)) + encoded
class LengthPrefixedObjectList(Unit):
    """A list of objects preceded by the total byte length of the list.

    ``length_unit`` encodes the byte length of the whole serialized list
    (not the number of items); every item is handled by ``object_unit``.
    """

    def __init__(self, length_unit: StructUnit, object_unit: FieldType):
        self.length_unit = length_unit
        self.object_unit = object_unit

    def __str__(self):
        return f"{self.__class__.__name__}({self.length_unit}, {self.object_unit})"

    def get_value(self):
        """Read the length prefix and parse that many bytes as a list.

        The payload is parsed with a separate ``Parser`` driving
        :meth:`_gen`, so only the prefixed bytes are visible to the items.
        """
        length = yield from self.length_unit.get_value()
        data, = yield from read_struct(f"{length}s")
        parser = Parser(self._gen())
        return parser.parse(data)

    def _gen(self):
        """Collect items for as long as the nested parser reports more data."""
        parser = yield from get_parser()
        lst = []
        # NOTE(review): presumably suspends until the nested parser has been
        # fed, so that has_more_data() is meaningful -- confirm against wait().
        yield from wait()
        while parser.has_more_data():
            lst.append((yield from self.object_unit.get_value()))
        return lst

    def __call__(self, obj_list: typing.List[FieldType]) -> bytes:
        """Serialize every item and prefix the result with its byte length.

        Raises:
            TypeError: ``object_unit`` is neither a schema class nor a
                ``Unit`` (previously this surfaced as an unbound local).
        """
        if isinstance(self.object_unit, BinarySchemaMetaclass):
            bytes_ = b"".join(bs.binary for bs in obj_list)
        elif isinstance(self.object_unit, Unit):
            bytes_ = b"".join(self.object_unit(bs) for bs in obj_list)
        else:
            raise TypeError(
                "object_unit must be a BinarySchema class or a Unit, "
                f"got {self.object_unit!r}"
            )
        return self.length_unit(len(bytes_)) + bytes_
class LengthPrefixedObject(LengthPrefixedObjectList):
    """A single object preceded by its serialized byte length.

    Parsing is inherited from ``LengthPrefixedObjectList.get_value``; only
    the nested generator and the serializer differ.
    """

    def _gen(self):
        """Parse exactly one object; leftover bytes raise ``ValueError``."""
        parser = yield from get_parser()
        value = yield from self.object_unit.get_value()
        if parser.has_more_data():
            raise ValueError("extra bytes left")
        return value

    def __call__(self, obj: FieldType) -> bytes:
        """Return the byte length of the serialized ``obj`` followed by it."""
        if isinstance(self.object_unit, BinarySchemaMetaclass):
            # Schema instances already carry their serialized form.
            payload = obj.binary
        else:
            payload = self.object_unit(obj)
        return self.length_unit(len(payload)) + payload
class Switch(Unit):
    """Pick the real field from ``cases`` using the value of field ``ref``.

    ``ref`` names another field of the enclosing schema; its value is the
    key looked up in ``cases``.
    """

    def __init__(self, ref: str, cases: typing.Mapping[typing.Any, FieldType]):
        self.ref = ref
        self.cases = cases

    def __str__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.ref}, {self.cases})"

    def get_value(self):
        """Parse with the case selected by the already-parsed ``ref`` field."""
        # Top of the stack is the mapping of the schema being parsed now.
        parsed_so_far = _mapping_stack[-1]
        selector = parsed_so_far[self.ref]
        chosen = self.cases[selector]
        return (yield from chosen.get_value())

    def __call__(self, obj) -> bytes:
        """Serialize ``obj`` with the case selected by the parent's ``ref``."""
        # Top of the stack is the schema instance being constructed now.
        owner = _parent_stack[-1]
        chosen = self.cases[getattr(owner, self.ref)]
        if isinstance(chosen, Unit):
            return chosen(obj)
        return obj.binary
class SizedIntEnum(Unit):
    """Enum member stored as an integer encoded with ``size_unit``."""

    def __init__(self, size_unit: StructUnit, enum_class):
        self.size_unit = size_unit
        self.enum_class = enum_class

    def __str__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.size_unit}, {self.enum_class})"

    def get_value(self):
        """Read the integer and convert it by calling ``enum_class``."""
        raw = yield from self.size_unit.get_value()
        member = self.enum_class(raw)
        return member

    def __call__(self, obj: enum.IntEnum) -> bytes:
        """Serialize the ``value`` of ``obj`` with ``size_unit``."""
        raw = obj.value
        return self.size_unit(raw)
class Convert(Unit):
    """Wrap ``unit`` and convert values on the way in and out.

    ``decode`` is applied to the value parsed by ``unit``; ``encode`` is
    applied to the user object before it is serialized by ``unit``.
    """

    def __init__(
        self, unit: FieldType, *, encode: typing.Callable, decode: typing.Callable
    ):
        self.unit = unit
        self.encode = encode
        self.decode = decode

    def __str__(self):
        # Mirrors the other units so composite units print readably.
        return f"{self.__class__.__name__}({self.unit})"

    def get_value(self):
        """Parse with the wrapped unit and return ``decode(value)``."""
        v = yield from self.unit.get_value()
        return self.decode(v)

    def __call__(self, obj) -> bytes:
        """Return the bytes of ``encode(obj)`` as serialized by ``unit``."""
        encoded = self.encode(obj)
        if isinstance(self.unit, BinarySchemaMetaclass):
            # Calling a schema class would build an instance, not bytes;
            # a schema instance carries its serialized form in ``binary``.
            return encoded.binary
        return self.unit(encoded)
def Group(**fields: FieldType) -> typing.Type[BinarySchema]:
    """Create an ad-hoc ``BinarySchema`` subclass named ``"Group"``.

    Each keyword argument becomes a class attribute of the new class, so
    ``Group(a=uint8, b=uint16)`` is equivalent to declaring a schema class
    with those two attributes.

    Returns:
        The newly created class (not an instance).
    """
    # ``**fields: T`` annotates each value, and the result is a class,
    # hence ``FieldType`` / ``Type[BinarySchema]`` in the signature.
    return type("Group", (BinarySchema,), fields)