def _varuint_encode(obj: int) -> bytes:
    """Encode a non-negative integer as a base-128 varint.

    The value is emitted seven bits at a time, least-significant group first. Every byte except
    the last has its high bit (0x80) set to signal that more bytes follow.

    Args:
        obj (int): Non-negative integer to encode.

    Returns:
        bytes: The encoded value (always at least one byte).

    Raises:
        ValueError: If ``obj`` is negative.
        TypeError: If ``obj`` is not an integer type.
    """
    import operator

    # Work on an arbitrary-precision Python int so that fixed-width integer scalars (e.g. NumPy)
    # cannot wrap around or change type while being shifted.
    obj = operator.index(obj)
    if obj < 0:
        raise ValueError(f'Expected non-negative integer, but got: {obj}.')
    ret = bytearray()
    while True:
        byte = obj & 0x7F
        obj >>= 7
        if not obj:
            # Final group: continuation bit left clear.
            ret.append(byte)
            return bytes(ret)
        ret.append(0x80 | byte)


def _varint_encode(obj: int) -> bytes:
    """Encode a signed integer as a base-128 varint.

    The sign is stored in the lowest bit and the magnitude in the remaining bits, i.e.
    ``x >= 0`` maps to ``2 * x`` and ``x < 0`` maps to ``2 * |x| + 1``. Note that this is a
    sign-magnitude layout, not the ZigZag mapping used by Protocol Buffers.

    Args:
        obj (int): Integer to encode.

    Returns:
        bytes: The encoded value.

    Raises:
        TypeError: If ``obj`` is not an integer type.
    """
    import operator

    # Convert before shifting: ``<< 1`` on a fixed-width integer scalar could silently overflow.
    obj = operator.index(obj)
    if 0 <= obj:
        obj = obj << 1
    else:
        obj = ((-obj) << 1) | 1
    return _varuint_encode(obj)


def _varuint_decode(stream: IO[bytes]) -> int:
    """Decode one base-128 varint from a binary stream.

    Reads one byte at a time and stops after the first byte whose high bit is clear, leaving the
    stream positioned just past the varint.

    Args:
        stream (IO[bytes]): Binary stream positioned at the start of a varint.

    Returns:
        int: The decoded non-negative integer.

    Raises:
        ValueError: If the stream ends before the terminating byte is read.
    """
    obj = 0
    shift = 0
    while True:
        chunk = stream.read(1)
        if not chunk:
            # Empty input, or the last byte read still had its continuation bit set.
            raise ValueError('Unexpected end of data while decoding a varint.')
        byte = chunk[0]
        obj |= (byte & 0x7F) << shift
        if byte < 0x80:
            return obj
        shift += 7


def _varint_decode(stream: IO[bytes]) -> int:
    """Decode one signed base-128 varint from a binary stream.

    Inverse of ``_varint_encode``: the lowest bit of the unsigned value is the sign and the
    remaining bits are the magnitude.

    Args:
        stream (IO[bytes]): Binary stream positioned at the start of a varint.

    Returns:
        int: The decoded integer.

    Raises:
        ValueError: If the stream ends before the terminating byte is read.
    """
    obj = _varuint_decode(stream)
    magnitude = obj >> 1
    return -magnitude if obj & 1 else magnitude


class VarUInt(Encoding):
    """Store an unsigned integer as a base-128 varint."""

    @classmethod
    def encode(cls, obj: int) -> bytes:
        """Encode a non-negative integer.

        Args:
            obj (int): Non-negative integer to encode.

        Returns:
            bytes: The varint bytes.

        Raises:
            ValueError: If ``obj`` is negative.
        """
        return _varuint_encode(obj)

    @classmethod
    def decode(cls, data: bytes) -> int:
        """Decode the varint at the start of ``data``.

        Bytes following the first complete varint are not examined.

        Args:
            data (bytes): Encoded bytes.

        Returns:
            int: The decoded non-negative integer.

        Raises:
            ValueError: If ``data`` is empty or ends in the middle of a varint.
        """
        stream = BytesIO(data)
        return _varuint_decode(stream)


class VarInt(Encoding):
    """Store an integer as a base-128 varint (sign kept in the lowest bit)."""

    @classmethod
    def encode(cls, obj: int) -> bytes:
        """Encode a signed integer.

        Args:
            obj (int): Integer to encode.

        Returns:
            bytes: The varint bytes.
        """
        return _varint_encode(obj)

    @classmethod
    def decode(cls, data: bytes) -> int:
        """Decode the signed varint at the start of ``data``.

        Bytes following the first complete varint are not examined.

        Args:
            data (bytes): Encoded bytes.

        Returns:
            int: The decoded integer.

        Raises:
            ValueError: If ``data`` is empty or ends in the middle of a varint.
        """
        stream = BytesIO(data)
        return _varint_decode(stream)
def test_varints(self):
    """Round-trip ``varint`` and ``varuint`` values through the MDS encode/decode functions."""
    from streaming.base.format.mds.encodings import mds_decode, mds_encode

    # Values on both sides of the 7-bit group boundaries, plus some that exceed 64 bits, so that
    # one-byte, multi-byte and arbitrary-precision encodings are all exercised.
    edges = [0, 1, 63, 64, 127, 128, 16383, 16384, 2**32, 2**63 - 1, 2**64, 2**70]

    signed = list(range(-700, 700, 7)) + edges + [-value for value in edges]
    for x in signed:
        y = mds_encode('varint', x)
        z = mds_decode('varint', y)
        assert x == z

    unsigned = list(range(0, 700, 7)) + edges
    for x in unsigned:
        y = mds_encode('varuint', x)
        z = mds_decode('varuint', y)
        assert x == z

    # The unsigned encoding has no representation for negative numbers.
    with pytest.raises(ValueError):
        mds_encode('varuint', -1)