Source code for iofree

"""`iofree` is an easy-to-use and powerful library \
to help you implement network protocols and binary parsers."""
import sys
import typing
from collections import deque
from enum import IntEnum, auto
from socket import SocketType
from struct import Struct

from .exceptions import NoResult, ParseError

# Package version string.
__version__ = "0.2.4"
# Sentinel a trap handler returns when it cannot be satisfied yet (for example,
# not enough buffered bytes); the parser then pauses until it is driven again.
_wait = object()
# Sentinel meaning "no result has been set", so that ``None`` stays usable as
# a real parse result.
_no_result = object()


class Traps(IntEnum):
    """Requests a parser generator can yield to its driving ``Parser``.

    Each member name is also the name of the ``Parser`` method that services
    the request, so the names must stay in sync with that class.
    """

    _read = 1
    _read_more = 2
    _read_until = 3
    _read_struct = 4
    _read_int = 5
    _wait = 6
    _peek = 7
    _wait_event = 8
    _get_parser = 9
class State(IntEnum):
    """Execution state of a ``Parser``'s processing loop."""

    # Paused: the pending trap cannot be satisfied yet.
    _state_wait = 1
    # Runnable: the next trap should be processed.
    _state_next = 2
    # Finished: the generator returned or failed.
    _state_end = 3
class Parser:
    """Drive a generator that yields ``(trap, *args)`` requests.

    Input bytes and events are pushed in with :meth:`send` and
    :meth:`send_event`. Each yielded trap is dispatched to the method of this
    class that has the same name as the ``Traps`` member. A handler either
    returns a value, which is sent back into the generator, or returns the
    module-level ``_wait`` sentinel, which pauses processing until the parser
    is driven again.

    Iterating over the parser drains the output events queued by
    :meth:`respond` as ``(data, close, exc, result)`` tuples.
    """

    def __init__(self, gen: typing.Generator):
        """Wrap ``gen`` and run it up to its first unsatisfied trap."""
        self.gen = gen
        self._input = bytearray()
        self._input_events: typing.Deque = deque()
        self._output_events: typing.Deque = deque()
        self._res = _no_result
        self._mapping_stack: typing.Deque = deque()
        # Value sent into the generator on its next resume.
        self._next_value = None
        # Trap that returned ``_wait`` and must be retried before resuming.
        self._last_trap: typing.Optional[tuple] = None
        # Offset where the next ``_read_until`` search starts.
        self._pos = 0
        self._state: State = State._state_wait
        self._process()

    def __repr__(self):
        return f"<{self.__class__.__qualname__}({self.gen})>"

    def __iter__(self):
        return self

    def __next__(self) -> typing.Any:
        """Pop the oldest queued output event; stop when the queue is empty."""
        if self._output_events:
            return self._output_events.popleft()
        raise StopIteration

    def parse(self, data: bytes, *, strict: bool = True) -> typing.Any:
        """Feed ``data`` and return the parse result.

        Raises ``ParseError`` if ``strict`` is true and unconsumed input
        remains, and ``NoResult`` if the generator has not produced a result.
        """
        self.send(data)
        if strict and self.has_more_data():
            raise ParseError("redundant data left")
        return self.get_result()

    def send(self, data: bytes = b"") -> None:
        """Append ``data`` to the input buffer and continue parsing."""
        self._input.extend(data)
        self._process()

    def read_output_bytes(self) -> bytes:
        """Drain the queued output events and return their data, concatenated.

        Only the ``data`` field of each event is collected; the ``close``,
        ``exc`` and ``result`` fields of the drained events are discarded.
        """
        # Collect the bytes-to-send field; the result field is normally the
        # ``_no_result`` sentinel and cannot be joined as bytes.
        return b"".join(to_send for to_send, _close, _exc, _result in self)

    def respond(
        self,
        *,
        data: bytes = b"",
        close: bool = False,
        exc: typing.Optional[Exception] = None,
        result: typing.Any = _no_result,
    ) -> None:
        """produce some event data to interact with a stream:
        data: bytes to send to the peer
        close: whether the socket should be closed
        exc: raise an exception to break the loop
        result: result to return
        """
        self._output_events.append((data, close, exc, result))

    def run(self, sock: SocketType) -> typing.Any:
        "reference implementation of how to deal with socket"
        self.send(b"")
        while True:
            for to_send, close, exc, result in self:
                if to_send:
                    sock.sendall(to_send)
                if close:
                    sock.close()
                if exc:
                    raise exc
                if result is not _no_result:
                    return result
            data = sock.recv(1024)
            if not data:
                # recv returned no bytes while the parser still needs input.
                raise ParseError("need data")
            self.send(data)

    @property
    def has_result(self) -> bool:
        """Whether a result has been set."""
        return self._res is not _no_result

    def get_result(self) -> typing.Any:
        """
        raises *NoResult* exception if no result has been set
        """
        self._process()
        if not self.has_result:
            raise NoResult("no result")
        return self._res

    def set_result(self, result) -> None:
        """Store ``result`` and queue an output event carrying it."""
        self._res = result
        self.respond(result=result)

    def finished(self) -> bool:
        """Whether the generator has ended (returned or failed)."""
        return self._state is State._state_end

    def _process(self) -> None:
        """Run traps until one has to wait or the generator ends."""
        if self._state is State._state_end:
            return
        self._state = State._state_next
        while self._state is State._state_next:
            self._next_state()

    def _next_state(self) -> None:
        """Resume the generator (or retry the pending trap) for one step."""
        if self._last_trap is None:
            try:
                trap, *args = self.gen.send(self._next_value)
            except StopIteration as e:
                # The generator returned: its return value is the result.
                self._state = State._state_end
                self.set_result(e.value)
                return
            except Exception:
                self._state = State._state_end
                tb = sys.exc_info()[2]
                raise ParseError(f"{self._next_value!r}").with_traceback(tb)
            else:
                if not isinstance(trap, Traps):
                    self._state = State._state_end
                    raise RuntimeError(f"Expect Traps object, but got: {trap}")
        else:
            # A previous trap could not complete; retry it with the same args.
            trap, *args = self._last_trap
        # The trap's member name is the name of the handler method.
        result = getattr(self, trap.name)(*args)
        if result is _wait:
            self._state = State._state_wait
            self._last_trap = (trap, *args)
        else:
            self._state = State._state_next
            self._next_value = result
            self._last_trap = None

    def readall(self) -> bytes:
        """
        retrieve data from input back
        """
        return self._read(0)

    def has_more_data(self) -> bool:
        "indicate whether input has some bytes left"
        return len(self._input) > 0

    def send_event(self, event: typing.Any) -> None:
        """Queue an input event and continue parsing."""
        self._input_events.append(event)
        self._process()

    def _wait_event(self):
        """Return the oldest queued input event, or wait if there is none."""
        if self._input_events:
            return self._input_events.popleft()
        return _wait

    def _wait(self) -> typing.Optional[object]:
        """Pause once, then return ``None`` when the trap is retried."""
        if not getattr(self, "_waiting", False):
            self._waiting = True
            return _wait
        self._waiting = False
        return None

    def _read(self, nbytes: int = 0, from_=None) -> typing.Union[object, bytes]:
        """Consume exactly ``nbytes`` bytes, or the whole buffer if 0."""
        buf = self._input if from_ is None else from_
        if nbytes == 0:
            data = bytes(buf)
            del buf[:]
            return data
        if len(buf) < nbytes:
            return _wait
        data = bytes(buf[:nbytes])
        del buf[:nbytes]
        return data

    def _read_more(self, nbytes: int = 1, from_=None) -> typing.Union[object, bytes]:
        """Consume the whole buffer once it holds at least ``nbytes`` bytes."""
        buf = self._input if from_ is None else from_
        if len(buf) < nbytes:
            return _wait
        data = bytes(buf)
        del buf[:]
        return data

    def _read_until(
        self, data: bytes, return_tail: bool = True, from_=None
    ) -> typing.Union[object, bytes]:
        """Consume bytes up to and including the first occurrence of ``data``.

        The delimiter is always removed from the buffer; it is included in the
        returned bytes only when ``return_tail`` is true.
        """
        buf = self._input if from_ is None else from_
        index = buf.find(data, self._pos)
        if index == -1:
            # Remember how far the search got, backing up far enough that a
            # delimiter split across two sends is still found next time.
            self._pos = len(buf) - len(data) + 1
            self._pos = self._pos if self._pos > 0 else 0
            return _wait
        size = index + len(data)
        if return_tail:
            data = bytes(buf[:size])
        else:
            data = bytes(buf[:index])
        del buf[:size]
        self._pos = 0
        return data

    def _read_struct(
        self, struct_obj: Struct, from_=None
    ) -> typing.Union[object, tuple]:
        """Unpack and consume ``struct_obj.size`` bytes from the buffer."""
        buf = self._input if from_ is None else from_
        size = struct_obj.size
        if len(buf) < size:
            return _wait
        result = struct_obj.unpack_from(buf)
        del buf[:size]
        return result

    def _read_int(
        self, nbytes: int, byteorder: str = "big", signed: bool = False, from_=None
    ) -> typing.Union[object, int]:
        """Consume ``nbytes`` bytes and decode them as an integer."""
        buf = self._input if from_ is None else from_
        if len(buf) < nbytes:
            return _wait
        # Read from the same buffer whose length was just checked.
        data = self._read(nbytes, from_)
        return int.from_bytes(data, byteorder, signed=signed)

    def _peek(self, nbytes: int = 1, from_=None) -> typing.Union[object, bytes]:
        """Return the first ``nbytes`` bytes without consuming them."""
        buf = self._input if from_ is None else from_
        if len(buf) < nbytes:
            return _wait
        return bytes(buf[:nbytes])

    def _get_parser(self) -> "Parser":
        """Return this parser, so the generator can call its methods."""
        return self
class LinkedNode:
    """Singly linked list node holding a parser and a link to the next node."""

    __slots__ = ("parser", "next")

    def __init__(self, parser: Parser, next_: typing.Optional["LinkedNode"]):
        # ``next_`` is None for the last node of a chain.
        self.parser, self.next = parser, next_
class ParserChain:
    """Chain parsers so that each parser's result is the next one's input.

    Data sent to the chain goes to the first parser. Iterating the chain
    yields the output events of every parser in order. When a parser that has
    a successor produces a result, that result is sent to the successor and
    the event is yielded with its result field replaced by ``_no_result``.
    """

    def __init__(self, *parsers: Parser):
        """Link ``parsers`` in the given order.

        Raises ``ValueError`` if no parser is given.
        """
        if not parsers:
            raise ValueError("ParserChain requires at least one parser")
        head: typing.Optional[LinkedNode] = None
        # Build back to front so each node can point at its successor.
        for item in reversed(parsers):
            head = LinkedNode(item, head)
        self.first = head

    def send(self, data: bytes) -> None:
        """Send ``data`` to the first parser in the chain."""
        self.first.parser.send(data)

    def __iter__(self):
        return self._get_events(self.first)

    def _get_events(
        self, node: LinkedNode
    ) -> typing.Generator[
        typing.Tuple[
            typing.Optional[bytes],
            typing.Optional[bool],
            typing.Optional[Exception],
            typing.Any,
        ],
        None,
        None,
    ]:
        """Yield output events from ``node`` and every node after it."""
        current = node
        while current:
            for data, close, exc, result in current.parser:
                if result is not _no_result and current.next:
                    # Intermediate result: feed the successor and hide it
                    # from the consumer of the chain.
                    current.next.parser.send(result)
                    yield (data, close, exc, _no_result)
                else:
                    yield (data, close, exc, result)
            current = current.next
def read(nbytes: int = 0, *, from_=None) -> typing.Generator[tuple, bytes, bytes]:
    """Read bytes from the input buffer.

    With ``nbytes == 0`` everything currently buffered is returned, and an
    empty result is valid. With ``nbytes > 0`` *exactly* ``nbytes`` bytes are
    returned.
    """
    request = (Traps._read, nbytes, from_)
    received = yield request
    return received
def read_more(nbytes: int = 1, *, from_=None) -> typing.Generator[tuple, bytes, bytes]:
    """Read everything buffered once *at least* ``nbytes`` bytes are available.

    Raises ``ValueError`` if ``nbytes`` is not positive.
    """
    if not nbytes > 0:
        raise ValueError(f"nbytes must > 0, but got {nbytes}")
    request = (Traps._read_more, nbytes, from_)
    received = yield request
    return received
def read_until(
    data: bytes, *, return_tail: bool = True, from_=None
) -> typing.Generator[tuple, bytes, bytes]:
    """Read until the byte sequence ``data`` appears in the input."""
    request = (Traps._read_until, data, return_tail, from_)
    received = yield request
    return received
def read_struct(fmt: str, *, from_=None) -> typing.Generator[tuple, tuple, tuple]:
    """Read data laid out according to the ``struct`` format string ``fmt``."""
    request = (Traps._read_struct, Struct(fmt), from_)
    unpacked = yield request
    return unpacked
def read_raw_struct(
    struct_obj: Struct, *, from_=None
) -> typing.Generator[tuple, tuple, tuple]:
    """Read data laid out according to an already compiled ``Struct``."""
    request = (Traps._read_struct, struct_obj, from_)
    unpacked = yield request
    return unpacked
def read_int(
    nbytes: int, byteorder: str = "big", *, signed: bool = False, from_=None
) -> typing.Generator[tuple, int, int]:
    """Read ``nbytes`` bytes and interpret them as an integer.

    Raises ``ValueError`` if ``nbytes`` is not positive.
    """
    if not nbytes > 0:
        raise ValueError(f"nbytes must > 0, but got {nbytes}")
    request = (Traps._read_int, nbytes, byteorder, signed, from_)
    number = yield request
    return number
def wait() -> typing.Generator[tuple, bytes, typing.Optional[object]]:
    """Pause until the parser is driven again by the next send."""
    request = (Traps._wait,)
    received = yield request
    return received
def peek(nbytes: int = 1, *, from_=None) -> typing.Generator[tuple, bytes, bytes]:
    """Look at the first ``nbytes`` bytes without removing them from the buffer.

    Raises ``ValueError`` if ``nbytes`` is not positive.
    """
    if not nbytes > 0:
        raise ValueError(f"nbytes must > 0, but got {nbytes}")
    request = (Traps._peek, nbytes, from_)
    received = yield request
    return received
def wait_event() -> typing.Generator[tuple, typing.Any, typing.Any]:
    """Pause until an event is available and return it."""
    request = (Traps._wait_event,)
    event = yield request
    return event
def get_parser() -> typing.Generator[tuple, Parser, Parser]:
    """Return the ``Parser`` object currently driving this generator."""
    request = (Traps._get_parser,)
    current = yield request
    return current
def parser(generator_func: typing.Callable) -> typing.Callable:
    """Decorator that attaches a ``parser`` factory to a generator function.

    The function itself is returned unchanged. Calling
    ``generator_func.parser(*args, **kwargs)`` creates the generator with
    those arguments and wraps it in a ``Parser``.
    """

    def create_parser(*args, **kwargs) -> Parser:
        generator = generator_func(*args, **kwargs)
        return Parser(generator)

    setattr(generator_func, "parser", create_parser)
    return generator_func