hat.util

Common utility functions

 1"""Common utility functions"""
 2
 3from hat.util import cron
 4from hat.util.bytes import (Bytes,
 5                            BytesBuffer)
 6from hat.util.callback import (RegisterCallbackHandle,
 7                               ExceptionCb,
 8                               CallbackRegistry)
 9from hat.util.first import first
10from hat.util.socket import (get_unused_tcp_port,
11                             get_unused_udp_port)
12from hat.util.sqlite3 import register_sqlite3_timestamp_converter
13
14
15__all__ = ['cron',
16           'Bytes',
17           'BytesBuffer',
18           'RegisterCallbackHandle',
19           'ExceptionCb',
20           'CallbackRegistry',
21           'first',
22           'get_unused_tcp_port',
23           'get_unused_udp_port',
24           'register_sqlite3_timestamp_converter']
Bytes = bytes | bytearray | memoryview
class BytesBuffer:
 9class BytesBuffer:
10    """Bytes buffer
11
12    All data added to BytesBuffer is considered immutable - it's content
13    (including size) should not be modified.
14
15    """
16
17    def __init__(self):
18        self._data = collections.deque()
19        self._data_len = 0
20
21    def __len__(self):
22        return self._data_len
23
24    def add(self, data: Bytes):
25        """Add data"""
26        if not data:
27            return
28
29        self._data.append(data)
30        self._data_len += len(data)
31
32    def read(self, n: int = -1) -> Bytes:
33        """Read up to `n` bytes
34
35        If ``n < 0``, read all data.
36
37        """
38        if n == 0:
39            return b''
40
41        if n < 0 or n >= self._data_len:
42            data, self._data = self._data, collections.deque()
43            data_len, self._data_len = self._data_len, 0
44
45        else:
46            data = collections.deque()
47            data_len = 0
48
49            while data_len < n:
50                head = self._data.popleft()
51                self._data_len -= len(head)
52
53                if data_len + len(head) <= n:
54                    data.append(head)
55                    data_len += len(head)
56
57                else:
58                    head = memoryview(head)
59                    head1, head2 = head[:n-data_len], head[n-data_len:]
60
61                    data.append(head1)
62                    data_len += len(head1)
63
64                    self._data.appendleft(head2)
65                    self._data_len += len(head2)
66
67        if len(data) < 1:
68            return b''
69
70        if len(data) < 2:
71            return data[0]
72
73        data_bytes = bytearray(data_len)
74        data_bytes_len = 0
75
76        while data:
77            head = data.popleft()
78            data_bytes[data_bytes_len:data_bytes_len+len(head)] = head
79            data_bytes_len += len(head)
80
81        return data_bytes
82
83    def clear(self) -> int:
84        """Clear data and return number of bytes cleared"""
85        self._data.clear()
86        data_len, self._data_len = self._data_len, 0
87        return data_len

Bytes buffer

All data added to BytesBuffer is considered immutable - its content (including size) should not be modified.

def add(self, data: bytes | bytearray | memoryview):
24    def add(self, data: Bytes):
25        """Add data"""
26        if not data:
27            return
28
29        self._data.append(data)
30        self._data_len += len(data)

Add data

def read(self, n: int = -1) -> bytes | bytearray | memoryview:
32    def read(self, n: int = -1) -> Bytes:
33        """Read up to `n` bytes
34
35        If ``n < 0``, read all data.
36
37        """
38        if n == 0:
39            return b''
40
41        if n < 0 or n >= self._data_len:
42            data, self._data = self._data, collections.deque()
43            data_len, self._data_len = self._data_len, 0
44
45        else:
46            data = collections.deque()
47            data_len = 0
48
49            while data_len < n:
50                head = self._data.popleft()
51                self._data_len -= len(head)
52
53                if data_len + len(head) <= n:
54                    data.append(head)
55                    data_len += len(head)
56
57                else:
58                    head = memoryview(head)
59                    head1, head2 = head[:n-data_len], head[n-data_len:]
60
61                    data.append(head1)
62                    data_len += len(head1)
63
64                    self._data.appendleft(head2)
65                    self._data_len += len(head2)
66
67        if len(data) < 1:
68            return b''
69
70        if len(data) < 2:
71            return data[0]
72
73        data_bytes = bytearray(data_len)
74        data_bytes_len = 0
75
76        while data:
77            head = data.popleft()
78            data_bytes[data_bytes_len:data_bytes_len+len(head)] = head
79            data_bytes_len += len(head)
80
81        return data_bytes

Read up to n bytes

If n < 0, read all data.

def clear(self) -> int:
83    def clear(self) -> int:
84        """Clear data and return number of bytes cleared"""
85        self._data.clear()
86        data_len, self._data_len = self._data_len, 0
87        return data_len

Clear data and return number of bytes cleared

class RegisterCallbackHandle(typing.NamedTuple):
 6class RegisterCallbackHandle(typing.NamedTuple):
 7    """Handle for canceling callback registration."""
 8
 9    cancel: Callable[[], None]
10    """cancel callback registration"""
11
12    def __enter__(self):
13        return self
14
15    def __exit__(self, *args):
16        self.cancel()

Handle for canceling callback registration.

RegisterCallbackHandle(cancel: Callable[[], None])

Create new instance of RegisterCallbackHandle(cancel,)

cancel: Callable[[], None]

cancel callback registration

ExceptionCb = collections.abc.Callable[[Exception], None]
class CallbackRegistry:
class CallbackRegistry:
    """Registry that enables callback registration and notification.

    Callbacks in the registry are notified sequentially with
    `CallbackRegistry.notify`. If a callback raises an exception, the
    exception is caught and `exception_cb` handler is called. Notification of
    subsequent callbacks is not interrupted. If handler is `None`, the
    exception is reraised and no subsequent callback is notified.

    Each notification works on a snapshot of the callbacks registered at the
    moment `CallbackRegistry.notify` is called. Callbacks may therefore
    register new callbacks or cancel registrations while being notified
    without affecting the notification in progress.

    Example::

        x = []
        y = []
        registry = CallbackRegistry()

        registry.register(x.append)
        registry.notify(1)

        with registry.register(y.append):
            registry.notify(2)

        registry.notify(3)

        assert x == [1, 2, 3]
        assert y == [2]

    """

    def __init__(self,
                 exception_cb: ExceptionCb | None = None):
        self._exception_cb = exception_cb
        self._cbs = []  # type: list[Callable]

    def register(self,
                 cb: Callable
                 ) -> RegisterCallbackHandle:
        """Register a callback.

        Returns a handle whose `cancel` removes the registration.

        """
        self._cbs.append(cb)
        return RegisterCallbackHandle(lambda: self._cbs.remove(cb))

    def notify(self, *args, **kwargs):
        """Notify all registered callbacks.

        All positional and keyword arguments are passed to each callback.

        """
        # iterate over a copy - a callback that cancels or adds a
        # registration would otherwise mutate the list being iterated and
        # cause other callbacks to be skipped
        for cb in tuple(self._cbs):
            try:
                cb(*args, **kwargs)
            except Exception as e:
                if self._exception_cb:
                    self._exception_cb(e)
                else:
                    raise

Registry that enables callback registration and notification.

Callbacks in the registry are notified sequentially with CallbackRegistry.notify. If a callback raises an exception, the exception is caught and exception_cb handler is called. Notification of subsequent callbacks is not interrupted. If handler is None, the exception is reraised and no subsequent callback is notified.

Example::

x = []
y = []
registry = CallbackRegistry()

registry.register(x.append)
registry.notify(1)

with registry.register(y.append):
    registry.notify(2)

registry.notify(3)

assert x == [1, 2, 3]
assert y == [2]
CallbackRegistry(exception_cb: Callable[[Exception], None] | None = None)
51    def __init__(self,
52                 exception_cb: ExceptionCb | None = None):
53        self._exception_cb = exception_cb
54        self._cbs = []  # type: list[Callable]
def register(self, cb: Callable) -> RegisterCallbackHandle:
56    def register(self,
57                 cb: Callable
58                 ) -> RegisterCallbackHandle:
59        """Register a callback."""
60        self._cbs.append(cb)
61        return RegisterCallbackHandle(lambda: self._cbs.remove(cb))

Register a callback.

def notify(self, *args, **kwargs):
63    def notify(self, *args, **kwargs):
64        """Notify all registered callbacks."""
65        for cb in self._cbs:
66            try:
67                cb(*args, **kwargs)
68            except Exception as e:
69                if self._exception_cb:
70                    self._exception_cb(e)
71                else:
72                    raise

Notify all registered callbacks.

def first( xs: Iterable[~T], fn: Callable[[~T], typing.Any] = <function <lambda>>, default: Optional[~T] = None) -> Optional[~T]:
 9def first(xs: Iterable[T],
10          fn: Callable[[T], typing.Any] = lambda _: True,
11          default: T | None = None
12          ) -> T | None:
13    """Return the first element from iterable that satisfies predicate `fn`,
14    or `default` if no such element exists.
15
16    Result of predicate `fn` can be of any type. Predicate is satisfied if it's
17    return value is truthy.
18
19    Args:
20        xs: collection
21        fn: predicate
22        default: default value
23
24    Example::
25
26        assert first(range(3)) == 0
27        assert first(range(3), lambda x: x > 1) == 2
28        assert first(range(3), lambda x: x > 2) is None
29        assert first(range(3), lambda x: x > 2, 123) == 123
30        assert first({1: 'a', 2: 'b', 3: 'c'}) == 1
31        assert first([], default=123) == 123
32
33    """
34    return next((i for i in xs if fn(i)), default)

Return the first element from iterable that satisfies predicate fn, or default if no such element exists.

Result of predicate fn can be of any type. Predicate is satisfied if its return value is truthy.

Arguments:
  • xs: collection
  • fn: predicate
  • default: default value

Example::

assert first(range(3)) == 0
assert first(range(3), lambda x: x > 1) == 2
assert first(range(3), lambda x: x > 2) is None
assert first(range(3), lambda x: x > 2, 123) == 123
assert first({1: 'a', 2: 'b', 3: 'c'}) == 1
assert first([], default=123) == 123
def get_unused_tcp_port(host: str = '127.0.0.1') -> int:
 6def get_unused_tcp_port(host: str = '127.0.0.1') -> int:
 7    """Search for unused TCP port"""
 8    with contextlib.closing(socket.socket()) as sock:
 9        sock.bind((host, 0))
10        return sock.getsockname()[1]

Search for unused TCP port

def get_unused_udp_port(host: str = '127.0.0.1') -> int:
def get_unused_udp_port(host: str = '127.0.0.1') -> int:
    """Return a UDP port on `host` that was unused at the time of the call

    A temporary datagram socket is bound to port 0 and closed again; the
    port number reported for that socket is returned.

    """
    sock = socket.socket(type=socket.SOCK_DGRAM)
    try:
        sock.bind((host, 0))
        port = sock.getsockname()[1]
    finally:
        sock.close()
    return port

Search for unused UDP port

def register_sqlite3_timestamp_converter():
 6def register_sqlite3_timestamp_converter():
 7    """Register modified timestamp converter
 8
 9    This converter is modification of standard library convertor taking into
10    account possible timezone info.
11
12    """
13
14    def convert_timestamp(val: bytes) -> datetime.datetime:
15        datepart, timetzpart = val.split(b" ")
16        if b"+" in timetzpart:
17            tzsign = 1
18            timepart, tzpart = timetzpart.split(b"+")
19        elif b"-" in timetzpart:
20            tzsign = -1
21            timepart, tzpart = timetzpart.split(b"-")
22        else:
23            timepart, tzpart = timetzpart, None
24        year, month, day = map(int, datepart.split(b"-"))
25        timepart_full = timepart.split(b".")
26        hours, minutes, seconds = map(int, timepart_full[0].split(b":"))
27        if len(timepart_full) == 2:
28            microseconds = int('{:0<6.6}'.format(timepart_full[1].decode()))
29        else:
30            microseconds = 0
31        if tzpart:
32            tzhours, tzminutes = map(int, tzpart.split(b":"))
33            tz = datetime.timezone(
34                tzsign * datetime.timedelta(hours=tzhours, minutes=tzminutes))
35        else:
36            tz = None
37
38        val = datetime.datetime(year, month, day, hours, minutes, seconds,
39                                microseconds, tz)
40        return val
41
42    sqlite3.register_converter("timestamp", convert_timestamp)

Register modified timestamp converter

This converter is a modification of the standard library converter that takes into account possible timezone info.