hat.util
Common utility functions
"""Common utility functions

Re-exports the public API of the ``hat.util`` submodules (``cron``,
``bytes``, ``callback``, ``first``, ``socket`` and ``sqlite3``) so that it
can be imported directly from ``hat.util``.

"""

from hat.util import cron
from hat.util.bytes import (Bytes,
                            BytesBuffer)
from hat.util.callback import (RegisterCallbackHandle,
                               ExceptionCb,
                               CallbackRegistry)
from hat.util.first import first
from hat.util.socket import (get_unused_tcp_port,
                             get_unused_udp_port)
from hat.util.sqlite3 import register_sqlite3_timestamp_converter


# Public API of the package: exactly the re-exported names imported above
# (these are the names bound by ``from hat.util import *``).
__all__ = ['cron',
           'Bytes',
           'BytesBuffer',
           'RegisterCallbackHandle',
           'ExceptionCb',
           'CallbackRegistry',
           'first',
           'get_unused_tcp_port',
           'get_unused_udp_port',
           'register_sqlite3_timestamp_converter']
class BytesBuffer:
    """Bytes buffer

    FIFO of byte chunks: `add` appends chunks without copying them and
    `read` removes data from the front. All data added to BytesBuffer is
    considered immutable - its content (including size) should not be
    modified, because the buffer keeps references to the added objects (and
    zero-copy `memoryview` slices of them) instead of copies.

    """

    def __init__(self):
        # Pending chunks, oldest first, and their combined length.
        self._data = collections.deque()
        self._data_len = 0

    def __len__(self):
        return self._data_len

    def add(self, data: Bytes):
        """Add data

        Empty chunks are ignored.

        """
        if data:
            self._data.append(data)
            self._data_len += len(data)

    def read(self, n: int = -1) -> Bytes:
        """Read up to `n` bytes

        If ``n < 0``, read all data.

        When the result consists of a single chunk, that chunk (or a
        `memoryview` slice of it) is returned without copying; when it spans
        several chunks, they are joined into a new `bytearray`. An empty read
        returns ``b''``.

        """
        if n == 0:
            return b''

        if 0 <= n < self._data_len:
            chunks, total = self._pop_front(n)
        else:
            # Request covers everything - hand over the whole queue.
            chunks, total = self._data, self._data_len
            self._data, self._data_len = collections.deque(), 0

        if not chunks:
            return b''
        if len(chunks) == 1:
            return chunks[0]

        joined = bytearray(total)
        offset = 0
        for chunk in chunks:
            end = offset + len(chunk)
            joined[offset:end] = chunk
            offset = end
        return joined

    def clear(self) -> int:
        """Clear data and return number of bytes cleared"""
        cleared = self._data_len
        self._data.clear()
        self._data_len = 0
        return cleared

    def _pop_front(self, n):
        # Detach exactly `n` bytes (0 < n < len(self)) from the front of the
        # buffer and return them as a list of chunks plus their total length.
        taken = []
        taken_len = 0
        while taken_len < n:
            chunk = self._data.popleft()
            self._data_len -= len(chunk)
            missing = n - taken_len
            if len(chunk) > missing:
                # Chunk straddles the boundary: split it with zero-copy
                # memoryview slices and push the remainder back.
                view = memoryview(chunk)
                chunk, rest = view[:missing], view[missing:]
                self._data.appendleft(rest)
                self._data_len += len(rest)
            taken.append(chunk)
            taken_len += len(chunk)
        return taken, taken_len
Bytes buffer
All data added to BytesBuffer is considered immutable - its content (including size) should not be modified.
24 def add(self, data: Bytes): 25 """Add data""" 26 if not data: 27 return 28 29 self._data.append(data) 30 self._data_len += len(data)
Add data
32 def read(self, n: int = -1) -> Bytes: 33 """Read up to `n` bytes 34 35 If ``n < 0``, read all data. 36 37 """ 38 if n == 0: 39 return b'' 40 41 if n < 0 or n >= self._data_len: 42 data, self._data = self._data, collections.deque() 43 data_len, self._data_len = self._data_len, 0 44 45 else: 46 data = collections.deque() 47 data_len = 0 48 49 while data_len < n: 50 head = self._data.popleft() 51 self._data_len -= len(head) 52 53 if data_len + len(head) <= n: 54 data.append(head) 55 data_len += len(head) 56 57 else: 58 head = memoryview(head) 59 head1, head2 = head[:n-data_len], head[n-data_len:] 60 61 data.append(head1) 62 data_len += len(head1) 63 64 self._data.appendleft(head2) 65 self._data_len += len(head2) 66 67 if len(data) < 1: 68 return b'' 69 70 if len(data) < 2: 71 return data[0] 72 73 data_bytes = bytearray(data_len) 74 data_bytes_len = 0 75 76 while data: 77 head = data.popleft() 78 data_bytes[data_bytes_len:data_bytes_len+len(head)] = head 79 data_bytes_len += len(head) 80 81 return data_bytes
Read up to `n` bytes
If `n < 0`, read all data.
class RegisterCallbackHandle(typing.NamedTuple):
    """Handle for canceling a callback registration.

    The handle can also be used as a context manager: leaving the ``with``
    block calls `cancel`, whether the block exits normally or with an
    exception (the exception is not suppressed).

    """

    cancel: Callable[[], None]
    """Cancel the callback registration."""

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.cancel()
Handle for canceling callback registration.
class CallbackRegistry:
    """Registry that enables callback registration and notification.

    Callbacks in the registry are notified sequentially with
    `CallbackRegistry.notify`. If a callback raises an exception, the
    exception is caught and `exception_cb` handler is called. Notification of
    subsequent callbacks is not interrupted. If handler is `None`, the
    exception is reraised and no subsequent callback is notified.

    Each notification iterates over a snapshot of the callbacks registered
    when `CallbackRegistry.notify` was called. Callbacks may therefore
    register new callbacks or cancel registrations (including their own)
    while being notified without causing other callbacks to be skipped.
    Callbacks registered during a notification are first called by the next
    notification, and a callback cancelled by an earlier callback is still
    called in the notification that is already in progress.

    Example::

        x = []
        y = []
        registry = CallbackRegistry()

        registry.register(x.append)
        registry.notify(1)

        with registry.register(y.append):
            registry.notify(2)

        registry.notify(3)

        assert x == [1, 2, 3]
        assert y == [2]

    """

    def __init__(self,
                 exception_cb: ExceptionCb | None = None):
        self._exception_cb = exception_cb
        self._cbs = []  # type: list[Callable]

    def register(self,
                 cb: Callable
                 ) -> RegisterCallbackHandle:
        """Register a callback.

        Returns a handle whose ``cancel`` removes this registration. The
        handle can also be used as a context manager. Calling ``cancel`` more
        than once is not supported: ``list.remove`` raises `ValueError` once
        the callback is no longer registered.

        """
        self._cbs.append(cb)
        return RegisterCallbackHandle(lambda: self._cbs.remove(cb))

    def notify(self, *args, **kwargs):
        """Notify all registered callbacks.

        Every callback is called with ``*args`` and ``**kwargs``.

        """
        # Iterate over a copy. Cancelling a registration removes it from
        # `self._cbs`, and removing from a list while iterating it shifts the
        # remaining items, which silently skips the next callback.
        for cb in list(self._cbs):
            try:
                cb(*args, **kwargs)
            except Exception as e:
                if self._exception_cb is not None:
                    self._exception_cb(e)
                else:
                    raise
Registry that enables callback registration and notification.
Callbacks in the registry are notified sequentially with `CallbackRegistry.notify`. If a callback raises an exception, the exception is caught and the `exception_cb` handler is called; notification of subsequent callbacks is not interrupted. If the handler is `None`, the exception is reraised and no subsequent callback is notified.
Example::
x = []
y = []
registry = CallbackRegistry()
registry.register(x.append)
registry.notify(1)
with registry.register(y.append):
registry.notify(2)
registry.notify(3)
assert x == [1, 2, 3]
assert y == [2]
def first(xs: Iterable[T],
          fn: Callable[[T], typing.Any] = lambda _: True,
          default: T | None = None
          ) -> T | None:
    """Return the first element of `xs` that satisfies predicate `fn`, or
    `default` if no such element exists.

    The predicate may return a value of any type; an element satisfies it
    when that value is truthy. Elements are consumed lazily, and consumption
    stops at the first match.

    Args:
        xs: collection
        fn: predicate (by default every element satisfies it)
        default: value returned when no element satisfies `fn`

    Example::

        assert first(range(3)) == 0
        assert first(range(3), lambda x: x > 1) == 2
        assert first(range(3), lambda x: x > 2) is None
        assert first(range(3), lambda x: x > 2, 123) == 123
        assert first({1: 'a', 2: 'b', 3: 'c'}) == 1
        assert first([], default=123) == 123

    """
    candidates = (item for item in xs if fn(item))
    return next(candidates, default)
Return the first element from iterable that satisfies predicate `fn`, or `default` if no such element exists.
Result of predicate `fn` can be of any type. Predicate is satisfied if its return value is truthy.
Arguments:
- xs: collection
- fn: predicate
- default: default value
Example::
assert first(range(3)) == 0
assert first(range(3), lambda x: x > 1) == 2
assert first(range(3), lambda x: x > 2) is None
assert first(range(3), lambda x: x > 2, 123) == 123
assert first({1: 'a', 2: 'b', 3: 'c'}) == 1
assert first([], default=123) == 123
def get_unused_tcp_port(host: str = '127.0.0.1') -> int:
    """Search for unused TCP port

    Binds an IPv4 TCP socket to port 0 on `host`, which lets the operating
    system pick a free port, and returns that port number. The socket is
    closed before returning, so the port is not reserved for the caller.

    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        _, port = sock.getsockname()
    return port
Search for unused TCP port
def get_unused_udp_port(host: str = '127.0.0.1') -> int:
    """Search for unused UDP port

    Binds an IPv4 UDP socket to port 0 on `host`, which lets the operating
    system pick a free port, and returns that port number. The socket is
    closed before returning, so the port is not reserved for the caller.

    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.bind((host, 0))
        _, port = sock.getsockname()
    return port
Search for unused UDP port
def register_sqlite3_timestamp_converter():
    """Register modified timestamp converter

    This converter is a modification of the standard library converter,
    taking into account possible timezone info. Values are expected in
    ``YYYY-MM-DD HH:MM:SS[.ffffff][(+|-)HH:MM[:SS[.ffffff]]]`` form, which
    is what ``datetime.datetime.isoformat(' ')`` produces. Values without an
    offset are converted to naive datetimes; values with an offset are
    converted to aware datetimes with a fixed `datetime.timezone`.

    The converter is registered for the ``timestamp`` declared column type
    with `sqlite3.register_converter`.

    """

    def parse_fraction(fraction: bytes) -> int:
        # Fractions of any length are right-padded / truncated to six
        # digits (microseconds), same as the standard library converter.
        return int('{:0<6.6}'.format(fraction.decode()))

    def parse_offset(sign: int, offset: bytes) -> datetime.timezone:
        # Offset is HH:MM with optional :SS and .ffffff parts. `isoformat`
        # emits seconds/microseconds for offsets that are not a whole number
        # of minutes, so those must be accepted for values to round-trip.
        offset_full = offset.split(b".")
        fields = [int(i) for i in offset_full[0].split(b":")]
        if len(fields) == 2:
            fields.append(0)
        hours, minutes, seconds = fields
        if len(offset_full) == 2:
            microseconds = parse_fraction(offset_full[1])
        else:
            microseconds = 0
        return datetime.timezone(
            sign * datetime.timedelta(hours=hours, minutes=minutes,
                                      seconds=seconds,
                                      microseconds=microseconds))

    def convert_timestamp(val: bytes) -> datetime.datetime:
        datepart, timetzpart = val.split(b" ")
        # The time part itself never contains '+' or '-', so either sign
        # marks the start of a UTC offset.
        if b"+" in timetzpart:
            tzsign = 1
            timepart, tzpart = timetzpart.split(b"+")
        elif b"-" in timetzpart:
            tzsign = -1
            timepart, tzpart = timetzpart.split(b"-")
        else:
            tzsign = 1
            timepart, tzpart = timetzpart, None

        year, month, day = map(int, datepart.split(b"-"))
        timepart_full = timepart.split(b".")
        hours, minutes, seconds = map(int, timepart_full[0].split(b":"))
        if len(timepart_full) == 2:
            microseconds = parse_fraction(timepart_full[1])
        else:
            microseconds = 0

        tz = parse_offset(tzsign, tzpart) if tzpart else None

        return datetime.datetime(year, month, day, hours, minutes, seconds,
                                 microseconds, tz)

    sqlite3.register_converter("timestamp", convert_timestamp)
Register modified timestamp converter
This converter is a modification of the standard library converter, taking into account possible timezone info.