diff --git a/dissect/util/stream.py b/dissect/util/stream.py index 7dc3b84..af93116 100644 --- a/dissect/util/stream.py +++ b/dissect/util/stream.py @@ -669,3 +669,104 @@ def readall(self) -> bytes: chunks.append(data) return b"".join(chunks) + + +class BitStream: + """Bit-level stream reader. + + Args: + fh: File-like object to read bits from. + """ + + def __init__(self, fh: BinaryIO): + self.fh = fh + self._byte_offset = fh.tell() + + self.buffer = 0 + self.bits = 0 + + def readable(self) -> bool: + """Indicate that the stream is readable.""" + return True + + def seekable(self) -> bool: + """Indicate that the stream is seekable.""" + return True + + def writable(self) -> bool: + """Indicate that the stream is not writable.""" + return False + + def seek(self, pos: int, whence: int = io.SEEK_SET) -> int: + """Seek the stream to the specified position in bits. + + Returns: + The new stream position after seeking. + """ + if whence == io.SEEK_SET: + byte_pos, bit_pos = divmod(pos, 8) + elif whence == io.SEEK_CUR: + current_pos = self.tell() + byte_pos, bit_pos = divmod(current_pos + pos, 8) + elif whence == io.SEEK_END: + self.fh.seek(0, io.SEEK_END) + end_pos = self.fh.tell() * 8 + byte_pos, bit_pos = divmod(end_pos + pos, 8) + else: + raise IOError("invalid whence value") + + self._byte_offset = byte_pos + self.bits = 0 + self.buffer = 0 + + if bit_pos > 0: + self.read(bit_pos) + + return self.tell() + + def tell(self) -> int: + """Get the current position in the stream in bits.""" + return (self._byte_offset * 8) - self.bits + + def read(self, n: int) -> int: + """Read n bits from the stream. + + Args: + n: Number of bits to read. + """ + value = self.peek(n) + self.remove(n) + return value + + def peek(self, n: int) -> int: + """Peek n bits from the stream without advancing. + + Args: + n: Number of bits to peek. + """ + if n == 0: + return 0 + + if n > self.bits: + while self.bits < n: + num_bytes = (n - self.bits + 7) // 8 + self.fh.seek(self._byte_offset) + if not (buf := self.fh.read(num_bytes)): + break + + new_bits = int.from_bytes(buf, "big") + num_new_bits = len(buf) * 8 + self.buffer = (self.buffer << num_new_bits) | new_bits + self.bits += num_new_bits + self._byte_offset += len(buf) + + return self.buffer >> (self.bits - min(n, self.bits)) + + def remove(self, n: int) -> None: + """Remove n bits from the stream. + + Args: + n: Number of bits to remove. + """ + self.bits -= min(n, self.bits) + self.buffer &= (1 << (self.bits)) - 1 diff --git a/tests/test_stream.py b/tests/test_stream.py index 0aae125..98ec45d 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -225,6 +225,67 @@ def test_zlib_stream() -> None: assert fh.read() == data +def test_bitstream() -> None: + """Test for correct bit reading behavior.""" + data = bytes([0b10101010, 0b11001100, 0b11110000, 0b00001111, 0b11111111, 0b00000000, 0b10101010, 0b01010101]) + fh = io.BytesIO(data) + bitstream = stream.BitStream(fh) + + assert bitstream.peek(4) == 0b1010 + assert bitstream.tell() == 0 + + assert bitstream.peek(4) == 0b1010 + assert bitstream.tell() == 0 + + assert bitstream.read(4) == 0b1010 + assert bitstream.tell() == 4 + + assert bitstream.peek(8) == 0b10101100 + assert bitstream.tell() == 4 + + assert bitstream.read(4) == 0b1010 + assert bitstream.tell() == 8 + + assert bitstream.peek(8) == 0b11001100 + assert bitstream.tell() == 8 + + assert bitstream.read(8) == 0b11001100 + assert bitstream.tell() == 16 + + assert bitstream.peek(4) == 0b1111 + assert bitstream.tell() == 16 + + bitstream.remove(4) + assert bitstream.tell() == 20 + + assert bitstream.peek(4) == 0b0000 + assert bitstream.tell() == 20 + + assert bitstream.read(8) == 0b00000000 + assert bitstream.tell() == 28 + + assert bitstream.read(33) == 0b11111111_11110000_00001010_10100101_0 + assert bitstream.tell() == 61 + + assert bitstream.seek(0) == 0 + assert bitstream.peek(16) == 0b10101010_11001100 + + assert bitstream.seek(4, io.SEEK_CUR) == 4 + assert bitstream.peek(8) == 0b10101100 + + assert bitstream.seek(-8, io.SEEK_END) == 56 + assert bitstream.peek(8) == 0b01010101 + + +def test_bitstream_empty() -> None: + """Test reading from an empty bit stream.""" + fh = io.BytesIO(b"") + bitstream = stream.BitStream(fh) + + assert bitstream.peek(8) == 0 + assert bitstream.read(8) == 0 + + class NullStream(stream.AlignedStream): def __init__(self, size: int | None, align: int = stream.STREAM_BUFFER_SIZE): super().__init__(size)