hat.util
Common utility functions
1"""Common utility functions""" 2 3import collections 4import contextlib 5import inspect 6import socket 7import typing 8import warnings 9 10 11T = typing.TypeVar('T') 12 13Bytes: typing.TypeAlias = bytes | bytearray | memoryview 14 15 16def register_type_alias(name: str): 17 """Register type alias 18 19 This function is temporary hack replacement for typing.TypeAlias. 20 21 It is expected that calling location will have `name` in local namespace 22 with type value. This function will wrap that type inside `typing.TypeVar` 23 and update annotations. 24 25 """ 26 warnings.warn("use typing.TypeAlias", DeprecationWarning) 27 frame = inspect.stack()[1][0] 28 f_locals = frame.f_locals 29 t = f_locals[name] 30 f_locals[name] = typing.TypeVar(name, t, t) 31 f_locals.setdefault('__annotations__', {})[name] = typing.Type[t] 32 33 34def first(xs: typing.Iterable[T], 35 fn: typing.Callable[[T], typing.Any] = lambda _: True, 36 default: T | None = None 37 ) -> T | None: 38 """Return the first element from iterable that satisfies predicate `fn`, 39 or `default` if no such element exists. 40 41 Result of predicate `fn` can be of any type. Predicate is satisfied if it's 42 return value is truthy. 
43 44 Args: 45 xs: collection 46 fn: predicate 47 default: default value 48 49 Example:: 50 51 assert first(range(3)) == 0 52 assert first(range(3), lambda x: x > 1) == 2 53 assert first(range(3), lambda x: x > 2) is None 54 assert first(range(3), lambda x: x > 2, 123) == 123 55 assert first({1: 'a', 2: 'b', 3: 'c'}) == 1 56 assert first([], default=123) == 123 57 58 """ 59 return next((i for i in xs if fn(i)), default) 60 61 62class RegisterCallbackHandle(typing.NamedTuple): 63 """Handle for canceling callback registration.""" 64 65 cancel: typing.Callable[[], None] 66 """cancel callback registration""" 67 68 def __enter__(self): 69 return self 70 71 def __exit__(self, *args): 72 self.cancel() 73 74 75ExceptionCb: typing.TypeAlias = typing.Callable[[Exception], None] 76"""Exception callback""" 77 78 79class CallbackRegistry: 80 """Registry that enables callback registration and notification. 81 82 Callbacks in the registry are notified sequentially with 83 `CallbackRegistry.notify`. If a callback raises an exception, the 84 exception is caught and `exception_cb` handler is called. Notification of 85 subsequent callbacks is not interrupted. If handler is `None`, the 86 exception is reraised and no subsequent callback is notified. 
87 88 Example:: 89 90 x = [] 91 y = [] 92 registry = CallbackRegistry() 93 94 registry.register(x.append) 95 registry.notify(1) 96 97 with registry.register(y.append): 98 registry.notify(2) 99 100 registry.notify(3) 101 102 assert x == [1, 2, 3] 103 assert y == [2] 104 105 """ 106 107 def __init__(self, 108 exception_cb: ExceptionCb | None = None): 109 self._exception_cb = exception_cb 110 self._cbs = [] # type: list[Callable] 111 112 def register(self, 113 cb: typing.Callable 114 ) -> RegisterCallbackHandle: 115 """Register a callback.""" 116 self._cbs.append(cb) 117 return RegisterCallbackHandle(lambda: self._cbs.remove(cb)) 118 119 def notify(self, *args, **kwargs): 120 """Notify all registered callbacks.""" 121 for cb in self._cbs: 122 try: 123 cb(*args, **kwargs) 124 except Exception as e: 125 if self._exception_cb: 126 self._exception_cb(e) 127 else: 128 raise 129 130 131def get_unused_tcp_port(host: str = '127.0.0.1') -> int: 132 """Search for unused TCP port""" 133 with contextlib.closing(socket.socket()) as sock: 134 sock.bind((host, 0)) 135 return sock.getsockname()[1] 136 137 138def get_unused_udp_port(host: str = '127.0.0.1') -> int: 139 """Search for unused UDP port""" 140 with contextlib.closing(socket.socket(type=socket.SOCK_DGRAM)) as sock: 141 sock.bind((host, 0)) 142 return sock.getsockname()[1] 143 144 145class BytesBuffer: 146 """Bytes buffer 147 148 All data added to BytesBuffer is considered immutable - it's content 149 (including size) should not be modified. 150 151 """ 152 153 def __init__(self): 154 self._data = collections.deque() 155 self._data_len = 0 156 157 def __len__(self): 158 return self._data_len 159 160 def add(self, data: Bytes): 161 """Add data""" 162 if not data: 163 return 164 165 self._data.append(data) 166 self._data_len += len(data) 167 168 def read(self, n: int = -1) -> Bytes: 169 """Read up to `n` bytes 170 171 If ``n < 0``, read all data. 
172 173 """ 174 if n == 0: 175 return b'' 176 177 if n < 0 or n >= self._data_len: 178 data, self._data = self._data, collections.deque() 179 data_len, self._data_len = self._data_len, 0 180 181 else: 182 data = collections.deque() 183 data_len = 0 184 185 while data_len < n: 186 head = self._data.popleft() 187 self._data_len -= len(head) 188 189 if data_len + len(head) <= n: 190 data.append(head) 191 data_len += len(head) 192 193 else: 194 head = memoryview(head) 195 head1, head2 = head[:n-data_len], head[n-data_len:] 196 197 data.append(head1) 198 data_len += len(head1) 199 200 self._data.appendleft(head2) 201 self._data_len += len(head2) 202 203 if len(data) < 1: 204 return b'' 205 206 if len(data) < 2: 207 return data[0] 208 209 data_bytes = bytearray(data_len) 210 data_bytes_len = 0 211 212 while data: 213 head = data.popleft() 214 data_bytes[data_bytes_len:data_bytes_len+len(head)] = head 215 data_bytes_len += len(head) 216 217 return data_bytes 218 219 def clear(self) -> int: 220 """Clear data and return number of bytes cleared""" 221 self._data.clear() 222 data_len, self._data_len = self._data_len, 0 223 return data_len
def register_type_alias(name: str):
    """Register type alias

    This function is a temporary hack replacement for `typing.TypeAlias`
    and is deprecated - use `typing.TypeAlias` instead.

    It is expected that the calling location will have `name` in its local
    namespace with a type value. This function wraps that type inside
    `typing.TypeVar` and updates the caller's ``__annotations__``.

    Args:
        name: name of the type alias in the caller's local namespace

    Raises:
        KeyError: if `name` is not defined in the caller's local namespace

    """
    # stacklevel=2 attributes the warning to the caller, not to this module
    warnings.warn("use typing.TypeAlias", DeprecationWarning, stacklevel=2)

    # only the caller's frame is needed - `inspect.stack()` would build
    # FrameInfo records (including source context) for the whole stack
    frame = inspect.currentframe().f_back
    try:
        f_locals = frame.f_locals
        t = f_locals[name]
        f_locals[name] = typing.TypeVar(name, t, t)
        f_locals.setdefault('__annotations__', {})[name] = typing.Type[t]
    finally:
        # drop the frame reference to avoid a reference cycle
        del frame
Register type alias
This function is a temporary hack replacement for `typing.TypeAlias`.
It is expected that the calling location will have `name` in its local namespace
with a type value. This function will wrap that type inside `typing.TypeVar`
and update annotations.
def first(xs: typing.Iterable[T],
          fn: typing.Callable[[T], typing.Any] = lambda _: True,
          default: T | None = None
          ) -> T | None:
    """Find the first element of `xs` accepted by predicate `fn`.

    `fn` may return a value of any type; an element is accepted when that
    value is truthy. When no element is accepted (or `xs` is empty),
    `default` is returned. Iteration stops at the first accepted element.

    Args:
        xs: collection
        fn: predicate
        default: value returned when no element is accepted

    Example::

        assert first(range(3)) == 0
        assert first(range(3), lambda x: x > 1) == 2
        assert first(range(3), lambda x: x > 2) is None
        assert first(range(3), lambda x: x > 2, 123) == 123
        assert first({1: 'a', 2: 'b', 3: 'c'}) == 1
        assert first([], default=123) == 123

    """
    # lazily filtered view of xs - only consumed up to the first match
    accepted = (element for element in xs if fn(element))
    return next(accepted, default)
Return the first element from iterable that satisfies predicate `fn`,
or `default` if no such element exists.
Result of predicate `fn` can be of any type. Predicate is satisfied if its
return value is truthy.
Arguments:
- xs: collection
- fn: predicate
- default: default value
Example::
assert first(range(3)) == 0
assert first(range(3), lambda x: x > 1) == 2
assert first(range(3), lambda x: x > 2) is None
assert first(range(3), lambda x: x > 2, 123) == 123
assert first({1: 'a', 2: 'b', 3: 'c'}) == 1
assert first([], default=123) == 123
class RegisterCallbackHandle(typing.NamedTuple):
    """Handle returned by a callback registration.

    Calling `cancel` removes the registration. The handle can also be used
    as a context manager: leaving the ``with`` block calls `cancel`.

    """

    # cancels the callback registration
    cancel: typing.Callable[[], None]

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # exception details are ignored; exceptions are not suppressed
        self.cancel()
Handle for canceling callback registration.
Create new instance of RegisterCallbackHandle(cancel,)
Inherited Members
- builtins.tuple
- index
- count
Exception callback
class CallbackRegistry:
    """Registry that enables callback registration and notification.

    Callbacks in the registry are notified sequentially with
    `CallbackRegistry.notify`. If a callback raises an exception, the
    exception is caught and `exception_cb` handler is called. Notification of
    subsequent callbacks is not interrupted. If handler is `None`, the
    exception is reraised and no subsequent callback is notified.

    Each notification works on a snapshot of the callbacks registered when
    `notify` was called, so callbacks may register new callbacks or cancel
    registrations (including their own) while being notified.

    Example::

        x = []
        y = []
        registry = CallbackRegistry()

        registry.register(x.append)
        registry.notify(1)

        with registry.register(y.append):
            registry.notify(2)

        registry.notify(3)

        assert x == [1, 2, 3]
        assert y == [2]

    """

    def __init__(self,
                 exception_cb: ExceptionCb | None = None):
        self._exception_cb = exception_cb
        # registered callbacks, in registration order
        self._cbs = []  # type: list[typing.Callable]

    def register(self,
                 cb: typing.Callable
                 ) -> RegisterCallbackHandle:
        """Register a callback.

        The returned handle's `cancel` removes one occurrence of `cb` from
        the registry (`ValueError` if `cb` is no longer registered).

        """
        self._cbs.append(cb)
        return RegisterCallbackHandle(lambda: self._cbs.remove(cb))

    def notify(self, *args, **kwargs):
        """Notify all registered callbacks.

        Every callback registered at the moment of this call is called with
        `args` and `kwargs`, in registration order. Registrations added or
        canceled by callbacks during this call do not change which callbacks
        this call notifies.

        """
        # iterate over a snapshot - a callback canceling its own registration
        # would otherwise shift the list and skip the next callback
        for cb in list(self._cbs):
            try:
                cb(*args, **kwargs)
            except Exception as e:
                if self._exception_cb:
                    self._exception_cb(e)
                else:
                    raise
Registry that enables callback registration and notification.
Callbacks in the registry are notified sequentially with
`CallbackRegistry.notify`. If a callback raises an exception, the
exception is caught and the `exception_cb` handler is called. Notification of
subsequent callbacks is not interrupted. If the handler is `None`, the
exception is reraised and no subsequent callback is notified.
Example::
x = []
y = []
registry = CallbackRegistry()
registry.register(x.append)
registry.notify(1)
with registry.register(y.append):
registry.notify(2)
registry.notify(3)
assert x == [1, 2, 3]
assert y == [2]
def register(self,
             cb: typing.Callable
             ) -> RegisterCallbackHandle:
    """Add `cb` to the registry.

    Returns a `RegisterCallbackHandle` whose `cancel` removes one occurrence
    of `cb` from the registry again (via `list.remove`, so `ValueError` is
    raised if `cb` is no longer registered).

    """
    self._cbs.append(cb)

    def cancel():
        self._cbs.remove(cb)

    return RegisterCallbackHandle(cancel)
Register a callback.
def notify(self, *args, **kwargs):
    """Notify all registered callbacks.

    Every callback registered at the moment of this call is called with
    `args` and `kwargs`, in registration order. Registrations added or
    canceled by callbacks during this call do not change which callbacks
    this call notifies.

    If a callback raises an `Exception`, it is passed to `exception_cb` and
    notification continues; without `exception_cb` the exception is reraised
    and the remaining callbacks are not notified.

    """
    # iterate over a snapshot - a callback canceling its own registration
    # would otherwise shift the list and skip the next callback
    for cb in list(self._cbs):
        try:
            cb(*args, **kwargs)
        except Exception as e:
            if self._exception_cb:
                self._exception_cb(e)
            else:
                raise
Notify all registered callbacks.
def get_unused_tcp_port(host: str = '127.0.0.1') -> int:
    """Return a TCP port number that is currently free on `host`.

    The port is obtained by binding a TCP socket to port 0 and reading back
    the assigned port. The socket is closed before returning, so the port is
    not reserved for the caller.

    """
    with socket.socket() as sock:
        sock.bind((host, 0))
        address = sock.getsockname()
    return address[1]
Search for unused TCP port
def get_unused_udp_port(host: str = '127.0.0.1') -> int:
    """Return a UDP port number that is currently free on `host`.

    The port is obtained by binding a datagram socket to port 0 and reading
    back the assigned port. The socket is closed before returning, so the
    port is not reserved for the caller.

    """
    with socket.socket(type=socket.SOCK_DGRAM) as sock:
        sock.bind((host, 0))
        address = sock.getsockname()
    return address[1]
Search for unused UDP port
class BytesBuffer:
    """Bytes buffer

    First-in, first-out queue of byte chunks. All data added to BytesBuffer
    is considered immutable - its content (including size) should not be
    modified.

    """

    def __init__(self):
        self._data = collections.deque()  # queued chunks, oldest first
        self._data_len = 0  # total length of all queued chunks

    def __len__(self):
        return self._data_len

    def add(self, data: Bytes):
        """Queue `data` (the object itself, not a copy); empty data is
        ignored."""
        if data:
            self._data.append(data)
            self._data_len += len(data)

    def read(self, n: int = -1) -> Bytes:
        """Remove and return up to `n` bytes from the front of the buffer.

        A negative `n` reads everything. The result is ``b''`` when nothing
        is read, a queued object (or a `memoryview` slice of it) when the
        result comes from a single chunk, and a new `bytearray` otherwise.

        """
        if n == 0:
            return b''

        if 0 <= n < self._data_len:
            chunks = []
            missing = n
            while missing > 0:
                chunk = self._data.popleft()
                size = len(chunk)
                if size > missing:
                    # split without copying; the tail stays queued
                    view = memoryview(chunk)
                    chunks.append(view[:missing])
                    self._data.appendleft(view[missing:])
                    self._data_len -= missing
                    break
                chunks.append(chunk)
                self._data_len -= size
                missing -= size
            total = n
        else:
            chunks = self._data
            total = self._data_len
            self._data = collections.deque()
            self._data_len = 0

        if not chunks:
            return b''
        if len(chunks) == 1:
            return chunks[0]

        joined = bytearray(total)
        offset = 0
        for chunk in chunks:
            end = offset + len(chunk)
            joined[offset:end] = chunk
            offset = end
        return joined

    def clear(self) -> int:
        """Drop all queued data and return how many bytes were dropped."""
        cleared = self._data_len
        self._data.clear()
        self._data_len = 0
        return cleared
Bytes buffer
All data added to BytesBuffer is considered immutable - its content (including size) should not be modified.
def add(self, data: Bytes):
    """Queue `data` at the end of the buffer.

    Empty `data` is ignored. The object itself is queued, not a copy.

    """
    if data:
        self._data.append(data)
        self._data_len += len(data)
Add data
def read(self, n: int = -1) -> Bytes:
    """Remove and return up to `n` bytes from the front of the buffer.

    A negative `n` reads everything. The result is ``b''`` when nothing is
    read, a queued object (or a `memoryview` slice of it) when the result
    comes from a single chunk, and a new `bytearray` otherwise.

    """
    if n == 0:
        return b''

    if 0 <= n < self._data_len:
        chunks = []
        missing = n
        while missing > 0:
            chunk = self._data.popleft()
            size = len(chunk)
            if size > missing:
                # split without copying; the tail stays queued
                view = memoryview(chunk)
                chunks.append(view[:missing])
                self._data.appendleft(view[missing:])
                self._data_len -= missing
                break
            chunks.append(chunk)
            self._data_len -= size
            missing -= size
        total = n
    else:
        chunks = self._data
        total = self._data_len
        self._data = collections.deque()
        self._data_len = 0

    if not chunks:
        return b''
    if len(chunks) == 1:
        return chunks[0]

    joined = bytearray(total)
    offset = 0
    for chunk in chunks:
        end = offset + len(chunk)
        joined[offset:end] = chunk
        offset = end
    return joined
Read up to `n` bytes
If `n < 0`, read all data.