hat.util
Common utility functions
"""Common utility functions"""

from hat.util import cron
from hat.util.bytes import Bytes
from hat.util.bytes import BytesBuffer
from hat.util.callback import RegisterCallbackHandle
from hat.util.callback import ExceptionCb
from hat.util.callback import CallbackRegistry
from hat.util.first import first
from hat.util.socket import get_unused_tcp_port
from hat.util.socket import get_unused_udp_port
from hat.util.sqlite3 import sqlite3_adapt_datetime
from hat.util.sqlite3 import sqlite3_convert_timestamp


# public API re-exported from the submodules above
__all__ = [
    'cron',
    'Bytes',
    'BytesBuffer',
    'RegisterCallbackHandle',
    'ExceptionCb',
    'CallbackRegistry',
    'first',
    'get_unused_tcp_port',
    'get_unused_udp_port',
    'sqlite3_adapt_datetime',
    'sqlite3_convert_timestamp',
]
class BytesBuffer:
    """Bytes buffer

    All data added to BytesBuffer is considered immutable - its content
    (including size) should not be modified.

    Added chunks are stored without copying. `BytesBuffer.read` returns a
    stored chunk itself (or a `memoryview` slice of it) when the result fits
    in a single chunk, and a newly allocated `bytearray` when several chunks
    have to be joined.

    """

    def __init__(self):
        self._data = collections.deque()  # stored chunks, oldest first
        self._data_len = 0  # total length of all stored chunks

    def __len__(self):
        return self._data_len

    def add(self, data: Bytes):
        """Add data"""
        # empty chunks are dropped so the deque never contains them
        if data:
            self._data.append(data)
            self._data_len += len(data)

    def read(self, n: int = -1) -> Bytes:
        """Read up to `n` bytes

        If ``n < 0``, read all data.

        """
        if n == 0:
            return b''

        if 0 <= n < self._data_len:
            chunks = self._pop_front(n)
            total = n
        else:
            # all data requested (or more than is available) - take everything
            chunks, self._data = self._data, collections.deque()
            total, self._data_len = self._data_len, 0

        if not chunks:
            return b''

        if len(chunks) == 1:
            return chunks[0]

        result = bytearray(total)
        offset = 0
        for chunk in chunks:
            end = offset + len(chunk)
            result[offset:end] = chunk
            offset = end
        return result

    def clear(self) -> int:
        """Clear data and return number of bytes cleared"""
        cleared = self._data_len
        self._data.clear()
        self._data_len = 0
        return cleared

    def _pop_front(self, n):
        """Detach exactly `n` bytes (``0 < n < len(self)``) from the front."""
        chunks = collections.deque()
        remaining = n
        while remaining > 0:
            head = self._data.popleft()
            self._data_len -= len(head)

            if len(head) <= remaining:
                chunks.append(head)
                remaining -= len(head)
                continue

            # chunk is longer than needed: split it without copying and put
            # the unread tail back in front of the buffer
            view = memoryview(head)
            chunks.append(view[:remaining])
            tail = view[remaining:]
            self._data.appendleft(tail)
            self._data_len += len(tail)
            remaining = 0

        return chunks
Bytes buffer
All data added to BytesBuffer is considered immutable - its content (including size) should not be modified.
def add(self, data: Bytes):
    """Add data

    Empty `data` is ignored. Otherwise `data` is appended to the buffer
    as is (without copying) and the buffered length grows by ``len(data)``.

    """
    if data:
        self._data.append(data)
        self._data_len += len(data)
Add data
def read(self, n: int = -1) -> Bytes:
    """Read up to `n` bytes

    If ``n < 0``, read all data.

    """
    if n == 0:
        return b''

    if 0 <= n < self._data_len:
        # partial read: take exactly `n` bytes from the front
        chunks = collections.deque()
        remaining = n
        while remaining > 0:
            head = self._data.popleft()
            self._data_len -= len(head)

            if len(head) <= remaining:
                chunks.append(head)
                remaining -= len(head)
                continue

            # split the chunk without copying; the unread tail goes back
            # in front of the buffer
            view = memoryview(head)
            chunks.append(view[:remaining])
            tail = view[remaining:]
            self._data.appendleft(tail)
            self._data_len += len(tail)
            remaining = 0

        total = n

    else:
        # all data requested (or more than is available) - take everything
        chunks, self._data = self._data, collections.deque()
        total, self._data_len = self._data_len, 0

    if not chunks:
        return b''

    if len(chunks) == 1:
        return chunks[0]

    result = bytearray(total)
    offset = 0
    for chunk in chunks:
        end = offset + len(chunk)
        result[offset:end] = chunk
        offset = end
    return result
Read up to n bytes
If n < 0, read all data.
class RegisterCallbackHandle(typing.NamedTuple):
    """Handle for canceling callback registration.

    The handle is also a context manager: entering returns the handle
    itself and leaving the ``with`` block calls `cancel`.

    """

    cancel: Callable[[], None]
    """cancel callback registration"""

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # returns None, so an exception raised in the with-body propagates
        self.cancel()
Handle for canceling callback registration.
class CallbackRegistry:
    """Registry that enables callback registration and notification.

    Callbacks in the registry are notified sequentially with
    `CallbackRegistry.notify`. If a callback raises an exception, the
    exception is caught and `exception_cb` handler is called. Notification of
    subsequent callbacks is not interrupted. If handler is `None`, the
    exception is reraised and no subsequent callback is notified.

    `CallbackRegistry.notify` works on a snapshot of the registered
    callbacks: callbacks registered or canceled while a notification is in
    progress take effect starting with the next `CallbackRegistry.notify`
    call.

    Example::

        x = []
        y = []
        registry = CallbackRegistry()

        registry.register(x.append)
        registry.notify(1)

        with registry.register(y.append):
            registry.notify(2)

        registry.notify(3)

        assert x == [1, 2, 3]
        assert y == [2]

    """

    def __init__(self,
                 exception_cb: ExceptionCb | None = None):
        self._exception_cb = exception_cb
        self._cbs = []  # type: list[Callable]

    def register(self,
                 cb: Callable
                 ) -> RegisterCallbackHandle:
        """Register a callback.

        The returned handle's ``cancel`` removes one occurrence of `cb` from
        the registry (raising `ValueError` if none is registered anymore);
        the handle can also be used as a context manager.

        """
        self._cbs.append(cb)
        return RegisterCallbackHandle(lambda: self._cbs.remove(cb))

    def notify(self, *args, **kwargs):
        """Notify all registered callbacks with ``*args`` and ``**kwargs``."""
        # Iterate over a copy: if a callback cancels its own (or another)
        # registration, removing from the live list would shift the remaining
        # entries under the loop and skip the next callback; callbacks
        # registered during notification would be called in the same round.
        for cb in list(self._cbs):
            try:
                cb(*args, **kwargs)
            except Exception as e:
                if self._exception_cb is None:
                    raise
                self._exception_cb(e)
Registry that enables callback registration and notification.
Callbacks in the registry are notified sequentially with
CallbackRegistry.notify. If a callback raises an exception, the
exception is caught and exception_cb handler is called. Notification of
subsequent callbacks is not interrupted. If handler is None, the
exception is reraised and no subsequent callback is notified.
Example::
x = []
y = []
registry = CallbackRegistry()
registry.register(x.append)
registry.notify(1)
with registry.register(y.append):
registry.notify(2)
registry.notify(3)
assert x == [1, 2, 3]
assert y == [2]
def first(xs: Iterable[T],
          fn: Callable[[T], typing.Any] = lambda _: True,
          default: T | None = None
          ) -> T | None:
    """Return the first element from iterable that satisfies predicate `fn`,
    or `default` if no such element exists.

    Result of predicate `fn` can be of any type. Predicate is satisfied if its
    return value is truthy. Iteration stops at the first match, so elements
    of `xs` after it are not consumed.

    Args:
        xs: collection
        fn: predicate
        default: default value

    Example::

        assert first(range(3)) == 0
        assert first(range(3), lambda x: x > 1) == 2
        assert first(range(3), lambda x: x > 2) is None
        assert first(range(3), lambda x: x > 2, 123) == 123
        assert first({1: 'a', 2: 'b', 3: 'c'}) == 1
        assert first([], default=123) == 123

    """
    matches = (x for x in xs if fn(x))
    return next(matches, default)
Return the first element from iterable that satisfies predicate fn,
or default if no such element exists.
Result of predicate fn can be of any type. Predicate is satisfied if its
return value is truthy.
Arguments:
- xs: collection
- fn: predicate
- default: default value
Example::
assert first(range(3)) == 0
assert first(range(3), lambda x: x > 1) == 2
assert first(range(3), lambda x: x > 2) is None
assert first(range(3), lambda x: x > 2, 123) == 123
assert first({1: 'a', 2: 'b', 3: 'c'}) == 1
assert first([], default=123) == 123
def get_unused_tcp_port(host: str = '127.0.0.1') -> int:
    """Search for unused TCP port

    A TCP socket is bound to port 0 on `host`, which lets the operating
    system pick a free port, and that port number is returned. The socket
    is closed before returning, so the port is not reserved - another
    process may bind it before the caller does.

    """
    with socket.socket() as sock:
        sock.bind((host, 0))
        address = sock.getsockname()
        return address[1]
Search for unused TCP port
def get_unused_udp_port(host: str = '127.0.0.1') -> int:
    """Search for unused UDP port

    A UDP socket is bound to port 0 on `host`, which lets the operating
    system pick a free port, and that port number is returned. The socket
    is closed before returning, so the port is not reserved - another
    process may bind it before the caller does.

    """
    with socket.socket(type=socket.SOCK_DGRAM) as sock:
        sock.bind((host, 0))
        address = sock.getsockname()
        return address[1]
Search for unused UDP port
def sqlite3_adapt_datetime(val: datetime.datetime) -> str:
    """SQLite3 datetime adapter

    Serializes `val` with `datetime.datetime.isoformat`, using a space
    instead of ``T`` between the date and the time part (the UTC offset is
    appended for timezone-aware values).

    Adapter usage::

        sqlite3.register_adapter(datetime.datetime, sqlite3_adapt_datetime)

    """
    return val.isoformat(sep=" ")
SQLite3 datetime adapter
Adapter usage::
sqlite3.register_adapter(datetime.datetime, sqlite3_adapt_datetime)
def sqlite3_convert_timestamp(val: bytes) -> datetime.datetime:
    """SQLite3 timestamp converter

    This converter is a modification of the standard library converter that
    takes possible timezone info into account.

    Expected input is ``YYYY-MM-DD HH:MM:SS[.f...]`` optionally followed by
    a UTC offset ``(+|-)HH:MM[:SS[.f...]]`` - the text produced by
    `datetime.datetime.isoformat` with ``" "`` separator. Fractional seconds
    are right-padded with zeros or truncated to microsecond precision. A
    naive datetime is returned when no offset is present.

    Converter usage::

        sqlite3.register_converter("timestamp", sqlite3_convert_timestamp)

    """
    def to_microseconds(fraction: bytes) -> int:
        # b'5' -> 500000, b'1234567' -> 123456, b'' -> 0
        return int('{:0<6.6}'.format(fraction.decode()))

    datepart, timetzpart = val.split(b" ")
    if b"+" in timetzpart:
        tzsign = 1
        timepart, tzpart = timetzpart.split(b"+")
    elif b"-" in timetzpart:
        tzsign = -1
        timepart, tzpart = timetzpart.split(b"-")
    else:
        timepart, tzpart = timetzpart, None

    year, month, day = map(int, datepart.split(b"-"))
    timepart_full = timepart.split(b".")
    hours, minutes, seconds = map(int, timepart_full[0].split(b":"))
    if len(timepart_full) == 2:
        microseconds = to_microseconds(timepart_full[1])
    else:
        microseconds = 0

    if tzpart:
        # isoformat emits seconds (and microseconds) in the offset when it
        # is not a whole number of minutes, e.g. '+05:30:15.000500'
        tz_hms, _, tz_fraction = tzpart.partition(b".")
        tz_fields = tz_hms.split(b":")
        if len(tz_fields) == 2:
            tz_fields.append(b"0")
        tzhours, tzminutes, tzseconds = map(int, tz_fields)
        tzmicroseconds = to_microseconds(tz_fraction) if tz_fraction else 0
        tz = datetime.timezone(
            tzsign * datetime.timedelta(hours=tzhours,
                                        minutes=tzminutes,
                                        seconds=tzseconds,
                                        microseconds=tzmicroseconds))
    else:
        tz = None

    return datetime.datetime(year, month, day, hours, minutes, seconds,
                             microseconds, tz)
SQLite3 timestamp converter
This converter is a modification of the standard library converter that takes possible timezone info into account.
Converter usage::
sqlite3.register_converter("timestamp", sqlite3_convert_timestamp)