hat.util

Common utility functions

  1"""Common utility functions"""
  2
  3import collections
  4import contextlib
  5import inspect
  6import socket
  7import typing
  8import warnings
  9
 10
 11T = typing.TypeVar('T')
 12
 13Bytes: typing.TypeAlias = bytes | bytearray | memoryview
 14
 15
 16def register_type_alias(name: str):
 17    """Register type alias
 18
 19    This function is temporary hack replacement for typing.TypeAlias.
 20
 21    It is expected that calling location will have `name` in local namespace
 22    with type value. This function will wrap that type inside `typing.TypeVar`
 23    and update annotations.
 24
 25    """
 26    warnings.warn("use typing.TypeAlias", DeprecationWarning)
 27    frame = inspect.stack()[1][0]
 28    f_locals = frame.f_locals
 29    t = f_locals[name]
 30    f_locals[name] = typing.TypeVar(name, t, t)
 31    f_locals.setdefault('__annotations__', {})[name] = typing.Type[t]
 32
 33
 34def first(xs: typing.Iterable[T],
 35          fn: typing.Callable[[T], typing.Any] = lambda _: True,
 36          default: T | None = None
 37          ) -> T | None:
 38    """Return the first element from iterable that satisfies predicate `fn`,
 39    or `default` if no such element exists.
 40
 41    Result of predicate `fn` can be of any type. Predicate is satisfied if it's
 42    return value is truthy.
 43
 44    Args:
 45        xs: collection
 46        fn: predicate
 47        default: default value
 48
 49    Example::
 50
 51        assert first(range(3)) == 0
 52        assert first(range(3), lambda x: x > 1) == 2
 53        assert first(range(3), lambda x: x > 2) is None
 54        assert first(range(3), lambda x: x > 2, 123) == 123
 55        assert first({1: 'a', 2: 'b', 3: 'c'}) == 1
 56        assert first([], default=123) == 123
 57
 58    """
 59    return next((i for i in xs if fn(i)), default)
 60
 61
 62class RegisterCallbackHandle(typing.NamedTuple):
 63    """Handle for canceling callback registration."""
 64
 65    cancel: typing.Callable[[], None]
 66    """cancel callback registration"""
 67
 68    def __enter__(self):
 69        return self
 70
 71    def __exit__(self, *args):
 72        self.cancel()
 73
 74
 75ExceptionCb: typing.TypeAlias = typing.Callable[[Exception], None]
 76"""Exception callback"""
 77
 78
 79class CallbackRegistry:
 80    """Registry that enables callback registration and notification.
 81
 82    Callbacks in the registry are notified sequentially with
 83    `CallbackRegistry.notify`. If a callback raises an exception, the
 84    exception is caught and `exception_cb` handler is called. Notification of
 85    subsequent callbacks is not interrupted. If handler is `None`, the
 86    exception is reraised and no subsequent callback is notified.
 87
 88    Example::
 89
 90        x = []
 91        y = []
 92        registry = CallbackRegistry()
 93
 94        registry.register(x.append)
 95        registry.notify(1)
 96
 97        with registry.register(y.append):
 98            registry.notify(2)
 99
100        registry.notify(3)
101
102        assert x == [1, 2, 3]
103        assert y == [2]
104
105    """
106
107    def __init__(self,
108                 exception_cb: ExceptionCb | None = None):
109        self._exception_cb = exception_cb
110        self._cbs = []  # type: list[Callable]
111
112    def register(self,
113                 cb: typing.Callable
114                 ) -> RegisterCallbackHandle:
115        """Register a callback."""
116        self._cbs.append(cb)
117        return RegisterCallbackHandle(lambda: self._cbs.remove(cb))
118
119    def notify(self, *args, **kwargs):
120        """Notify all registered callbacks."""
121        for cb in self._cbs:
122            try:
123                cb(*args, **kwargs)
124            except Exception as e:
125                if self._exception_cb:
126                    self._exception_cb(e)
127                else:
128                    raise
129
130
def get_unused_tcp_port(host: str = '127.0.0.1') -> int:
    """Return a TCP port number that is currently unused on `host`.

    A stream socket is bound to port 0, which lets the operating system
    choose a free port. The socket is closed before returning, so the port
    is not reserved: another process may take it before the caller binds it.

    """
    with socket.socket() as probe:
        probe.bind((host, 0))
        port = probe.getsockname()[1]
    return port
136
137
def get_unused_udp_port(host: str = '127.0.0.1') -> int:
    """Return a UDP port number that is currently unused on `host`.

    A datagram socket is bound to port 0, which lets the operating system
    choose a free port. The socket is closed before returning, so the port
    is not reserved: another process may take it before the caller binds it.

    """
    with socket.socket(type=socket.SOCK_DGRAM) as probe:
        probe.bind((host, 0))
        port = probe.getsockname()[1]
    return port
143
144
class BytesBuffer:
    """Bytes buffer

    FIFO buffer of byte chunks. Chunks are stored by reference, not copied.
    All data added to BytesBuffer is therefore considered immutable: its
    content (including size) should not be modified after it is added.

    NOTE(review): lengths are measured with ``len()``. For a memoryview with
    itemsize > 1, that counts items rather than bytes.

    """

    def __init__(self):
        self._data = collections.deque()  # pending chunks, oldest first
        self._data_len = 0  # sum of len() over pending chunks

    def __len__(self):
        return self._data_len

    def add(self, data: Bytes):
        """Add data

        Empty data is ignored. Otherwise `data` is appended as a new chunk.

        """
        if data:
            self._data.append(data)
            self._data_len += len(data)

    def read(self, n: int = -1) -> Bytes:
        """Read up to `n` bytes

        If ``n < 0``, read all data.

        Returns ``b''`` when nothing is read. When the result is a single
        chunk, that chunk is returned as stored (or as a zero-copy memoryview
        slice of it). Otherwise a new `bytearray` joining the chunks is
        returned.

        """
        if n == 0:
            return b''

        if n < 0 or n >= self._data_len:
            # take everything: detach the whole deque
            chunks, total = self._data, self._data_len
            self._data, self._data_len = collections.deque(), 0
        else:
            chunks = collections.deque()
            total = 0
            while total < n:
                chunk = self._data.popleft()
                self._data_len -= len(chunk)
                room = n - total
                if len(chunk) <= room:
                    chunks.append(chunk)
                    total += len(chunk)
                else:
                    # split without copying; the tail goes back to the front
                    view = memoryview(chunk)
                    head, tail = view[:room], view[room:]
                    chunks.append(head)
                    total += len(head)
                    self._data.appendleft(tail)
                    self._data_len += len(tail)

        if not chunks:
            return b''
        if len(chunks) == 1:
            return chunks[0]

        joined = bytearray(total)
        offset = 0
        for chunk in chunks:
            end = offset + len(chunk)
            joined[offset:end] = chunk
            offset = end
        return joined

    def clear(self) -> int:
        """Clear data and return number of bytes cleared"""
        cleared = self._data_len
        self._data.clear()
        self._data_len = 0
        return cleared
Bytes: TypeAlias = bytes | bytearray | memoryview
def register_type_alias(name: str):
def register_type_alias(name: str):
    """Register type alias

    Deprecated: this function is a temporary hack replacement for
    `typing.TypeAlias`. Every call emits a `DeprecationWarning` that is
    attributed to the calling code.

    The calling location is expected to have `name` bound to a type in its
    local namespace. That binding is replaced with
    ``typing.TypeVar(name, t, t)`` and ``__annotations__[name]`` is set to
    ``typing.Type[t]``.

    Bindings are written through the caller's ``frame.f_locals``. At module
    and class scope, that is the real namespace dict.
    NOTE(review): inside a function body, writes to ``f_locals`` may not be
    reflected in the function's locals. Confirm that callers only use this
    at module/class level.

    Args:
        name: name of the local binding that holds the aliased type

    """
    # stacklevel=2 attributes the warning to the caller, not to this module
    warnings.warn("use typing.TypeAlias", DeprecationWarning, stacklevel=2)
    # Only the caller's frame is needed. `inspect.stack()` would also build
    # FrameInfo records (reading source context) for every frame on the stack.
    frame = inspect.currentframe().f_back
    f_locals = frame.f_locals
    t = f_locals[name]
    f_locals[name] = typing.TypeVar(name, t, t)
    f_locals.setdefault('__annotations__', {})[name] = typing.Type[t]

Register type alias

This function is temporary hack replacement for typing.TypeAlias.

The calling location is expected to have name bound to a type value in its local namespace. This function wraps that type inside typing.TypeVar and updates the annotations.

def first( xs: Iterable[~T], fn: Callable[[~T], Any] = <function <lambda>>, default: Optional[~T] = None) -> Optional[~T]:
def first(xs: typing.Iterable[T],
          fn: typing.Callable[[T], typing.Any] = lambda _: True,
          default: T | None = None
          ) -> T | None:
    """Return the first element from iterable that satisfies predicate `fn`,
    or `default` if no such element exists.

    The result of predicate `fn` can be of any type. The predicate is
    satisfied if its return value is truthy. Iteration stops at the first
    match, so the remainder of `xs` is not consumed.

    Args:
        xs: collection
        fn: predicate
        default: default value

    Example::

        assert first(range(3)) == 0
        assert first(range(3), lambda x: x > 1) == 2
        assert first(range(3), lambda x: x > 2) is None
        assert first(range(3), lambda x: x > 2, 123) == 123
        assert first({1: 'a', 2: 'b', 3: 'c'}) == 1
        assert first([], default=123) == 123

    """
    for item in xs:
        if fn(item):
            return item
    return default

Return the first element from iterable that satisfies predicate fn, or default if no such element exists.

The result of predicate fn can be of any type. The predicate is satisfied if its return value is truthy.

Arguments:
  • xs: collection
  • fn: predicate
  • default: default value

Example::

assert first(range(3)) == 0
assert first(range(3), lambda x: x > 1) == 2
assert first(range(3), lambda x: x > 2) is None
assert first(range(3), lambda x: x > 2, 123) == 123
assert first({1: 'a', 2: 'b', 3: 'c'}) == 1
assert first([], default=123) == 123
class RegisterCallbackHandle(typing.NamedTuple):
class RegisterCallbackHandle(typing.NamedTuple):
    """Handle for canceling callback registration.

    The handle is also a context manager: leaving the ``with`` block calls
    `cancel`.

    """

    cancel: typing.Callable[[], None]
    """cancel callback registration"""

    def __enter__(self):
        # ``with ... as h`` binds the handle itself
        return self

    def __exit__(self, *exc_info):
        # cancel's return value is discarded, so exceptions raised inside
        # the ``with`` block always propagate
        self.cancel()

Handle for canceling callback registration.

RegisterCallbackHandle(cancel: Callable[[], NoneType])

Create new instance of RegisterCallbackHandle(cancel,)

cancel: Callable[[], NoneType]

cancel callback registration

Inherited Members
builtins.tuple
index
count
ExceptionCb: TypeAlias = typing.Callable[[Exception], NoneType]

Exception callback

class CallbackRegistry:
 80class CallbackRegistry:
 81    """Registry that enables callback registration and notification.
 82
 83    Callbacks in the registry are notified sequentially with
 84    `CallbackRegistry.notify`. If a callback raises an exception, the
 85    exception is caught and `exception_cb` handler is called. Notification of
 86    subsequent callbacks is not interrupted. If handler is `None`, the
 87    exception is reraised and no subsequent callback is notified.
 88
 89    Example::
 90
 91        x = []
 92        y = []
 93        registry = CallbackRegistry()
 94
 95        registry.register(x.append)
 96        registry.notify(1)
 97
 98        with registry.register(y.append):
 99            registry.notify(2)
100
101        registry.notify(3)
102
103        assert x == [1, 2, 3]
104        assert y == [2]
105
106    """
107
108    def __init__(self,
109                 exception_cb: ExceptionCb | None = None):
110        self._exception_cb = exception_cb
111        self._cbs = []  # type: list[Callable]
112
113    def register(self,
114                 cb: typing.Callable
115                 ) -> RegisterCallbackHandle:
116        """Register a callback."""
117        self._cbs.append(cb)
118        return RegisterCallbackHandle(lambda: self._cbs.remove(cb))
119
120    def notify(self, *args, **kwargs):
121        """Notify all registered callbacks."""
122        for cb in self._cbs:
123            try:
124                cb(*args, **kwargs)
125            except Exception as e:
126                if self._exception_cb:
127                    self._exception_cb(e)
128                else:
129                    raise

Registry that enables callback registration and notification.

Callbacks in the registry are notified sequentially with CallbackRegistry.notify. If a callback raises an exception, the exception is caught and the exception_cb handler is called. Notification of subsequent callbacks is not interrupted. If the handler is None, the exception is re-raised and no subsequent callback is notified.

Example::

x = []
y = []
registry = CallbackRegistry()

registry.register(x.append)
registry.notify(1)

with registry.register(y.append):
    registry.notify(2)

registry.notify(3)

assert x == [1, 2, 3]
assert y == [2]
CallbackRegistry(exception_cb: Optional[Callable[[Exception], NoneType]] = None)
108    def __init__(self,
109                 exception_cb: ExceptionCb | None = None):
110        self._exception_cb = exception_cb
111        self._cbs = []  # type: list[Callable]
def register(self, cb: Callable) -> hat.util.RegisterCallbackHandle:
113    def register(self,
114                 cb: typing.Callable
115                 ) -> RegisterCallbackHandle:
116        """Register a callback."""
117        self._cbs.append(cb)
118        return RegisterCallbackHandle(lambda: self._cbs.remove(cb))

Register a callback.

def notify(self, *args, **kwargs):
120    def notify(self, *args, **kwargs):
121        """Notify all registered callbacks."""
122        for cb in self._cbs:
123            try:
124                cb(*args, **kwargs)
125            except Exception as e:
126                if self._exception_cb:
127                    self._exception_cb(e)
128                else:
129                    raise

Notify all registered callbacks.

def get_unused_tcp_port(host: str = '127.0.0.1') -> int:
def get_unused_tcp_port(host: str = '127.0.0.1') -> int:
    """Return a TCP port number that is currently unused on `host`.

    A stream socket is bound to port 0, which lets the operating system
    choose a free port. The socket is closed before returning, so the port
    is not reserved: another process may take it before the caller binds it.

    """
    with socket.socket() as probe:
        probe.bind((host, 0))
        port = probe.getsockname()[1]
    return port

Search for unused TCP port

def get_unused_udp_port(host: str = '127.0.0.1') -> int:
def get_unused_udp_port(host: str = '127.0.0.1') -> int:
    """Return a UDP port number that is currently unused on `host`.

    A datagram socket is bound to port 0, which lets the operating system
    choose a free port. The socket is closed before returning, so the port
    is not reserved: another process may take it before the caller binds it.

    """
    with socket.socket(type=socket.SOCK_DGRAM) as probe:
        probe.bind((host, 0))
        port = probe.getsockname()[1]
    return port

Search for unused UDP port

class BytesBuffer:
class BytesBuffer:
    """Bytes buffer

    FIFO buffer of byte chunks. Chunks are stored by reference, not copied.
    All data added to BytesBuffer is therefore considered immutable: its
    content (including size) should not be modified after it is added.

    NOTE(review): lengths are measured with ``len()``. For a memoryview with
    itemsize > 1, that counts items rather than bytes.

    """

    def __init__(self):
        self._data = collections.deque()  # pending chunks, oldest first
        self._data_len = 0  # sum of len() over pending chunks

    def __len__(self):
        return self._data_len

    def add(self, data: Bytes):
        """Add data

        Empty data is ignored. Otherwise `data` is appended as a new chunk.

        """
        if data:
            self._data.append(data)
            self._data_len += len(data)

    def read(self, n: int = -1) -> Bytes:
        """Read up to `n` bytes

        If ``n < 0``, read all data.

        Returns ``b''`` when nothing is read. When the result is a single
        chunk, that chunk is returned as stored (or as a zero-copy memoryview
        slice of it). Otherwise a new `bytearray` joining the chunks is
        returned.

        """
        if n == 0:
            return b''

        if n < 0 or n >= self._data_len:
            # take everything: detach the whole deque
            chunks, total = self._data, self._data_len
            self._data, self._data_len = collections.deque(), 0
        else:
            chunks = collections.deque()
            total = 0
            while total < n:
                chunk = self._data.popleft()
                self._data_len -= len(chunk)
                room = n - total
                if len(chunk) <= room:
                    chunks.append(chunk)
                    total += len(chunk)
                else:
                    # split without copying; the tail goes back to the front
                    view = memoryview(chunk)
                    head, tail = view[:room], view[room:]
                    chunks.append(head)
                    total += len(head)
                    self._data.appendleft(tail)
                    self._data_len += len(tail)

        if not chunks:
            return b''
        if len(chunks) == 1:
            return chunks[0]

        joined = bytearray(total)
        offset = 0
        for chunk in chunks:
            end = offset + len(chunk)
            joined[offset:end] = chunk
            offset = end
        return joined

    def clear(self) -> int:
        """Clear data and return number of bytes cleared"""
        cleared = self._data_len
        self._data.clear()
        self._data_len = 0
        return cleared

Bytes buffer

All data added to BytesBuffer is considered immutable - its content (including size) should not be modified.

def add(self, data: bytes | bytearray | memoryview):
161    def add(self, data: Bytes):
162        """Add data"""
163        if not data:
164            return
165
166        self._data.append(data)
167        self._data_len += len(data)

Add data

def read(self, n: int = -1) -> bytes | bytearray | memoryview:
169    def read(self, n: int = -1) -> Bytes:
170        """Read up to `n` bytes
171
172        If ``n < 0``, read all data.
173
174        """
175        if n == 0:
176            return b''
177
178        if n < 0 or n >= self._data_len:
179            data, self._data = self._data, collections.deque()
180            data_len, self._data_len = self._data_len, 0
181
182        else:
183            data = collections.deque()
184            data_len = 0
185
186            while data_len < n:
187                head = self._data.popleft()
188                self._data_len -= len(head)
189
190                if data_len + len(head) <= n:
191                    data.append(head)
192                    data_len += len(head)
193
194                else:
195                    head = memoryview(head)
196                    head1, head2 = head[:n-data_len], head[n-data_len:]
197
198                    data.append(head1)
199                    data_len += len(head1)
200
201                    self._data.appendleft(head2)
202                    self._data_len += len(head2)
203
204        if len(data) < 1:
205            return b''
206
207        if len(data) < 2:
208            return data[0]
209
210        data_bytes = bytearray(data_len)
211        data_bytes_len = 0
212
213        while data:
214            head = data.popleft()
215            data_bytes[data_bytes_len:data_bytes_len+len(head)] = head
216            data_bytes_len += len(head)
217
218        return data_bytes

Read up to n bytes

If n < 0, read all data.

def clear(self) -> int:
220    def clear(self) -> int:
221        """Clear data and return number of bytes cleared"""
222        self._data.clear()
223        data_len, self._data_len = self._data_len, 0
224        return data_len

Clear data and return number of bytes cleared