Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion parquet/src/encodings/rle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,15 @@ impl RleEncoder {
/// Size, in number of `i32` elements, of the buffer used for RLE batch reading.
const RLE_DECODER_INDEX_BUFFER_SIZE: usize = 1024;

/// One decoded batch of indices, as passed to the callback of
/// [`RleDecoder::get_batch_direct`].
///
/// The common value traits are derived so that callers can log, compare and copy
/// batches; every payload is either a plain integer or a shared slice reference.
#[cfg(feature = "arrow")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RleDecodedBatch<'a> {
    /// An RLE run: the same `index` repeated `count` times.
    Rle {
        /// The index value shared by every element of the run.
        index: i32,
        /// How many times `index` is repeated in this batch.
        count: usize,
    },
    /// A batch of bit-packed indices, already unpacked to one `i32` per element.
    BitPacked(&'a [i32]),
}

/// A RLE/Bit-Packing hybrid decoder.
pub struct RleDecoder {
// Number of bits used to encode the value. Must be between [0, 64].
Expand Down Expand Up @@ -414,6 +423,52 @@ impl RleDecoder {
Ok(values_read)
}

/// Decodes up to `max_values` indices, passing each decoded batch to `f`.
///
/// * For an RLE run, `f` receives a single [`RleDecodedBatch::Rle`] so the caller can
///   fill its output directly instead of reading the repeated index element by element.
/// * For a bit-packed run, `f` receives one or more [`RleDecodedBatch::BitPacked`]
///   slices of at most 1024 indices each. The slice borrows a scratch buffer local to
///   this call that is overwritten by the next batch, so it must be consumed inside `f`.
///
/// Returns the number of indices handed to `f`. This is less than `max_values` when
/// `reload` reports that no further run is available.
///
/// # Errors
///
/// Returns an error if the bit reader has not been set, if an RLE run is pending but
/// no current value is stored, or if `reload` fails.
#[cfg(feature = "arrow")]
pub fn get_batch_direct<F>(&mut self, max_values: usize, mut f: F) -> Result<usize>
where
    F: FnMut(RleDecodedBatch<'_>),
{
    let mut values_read = 0;
    // Scratch space for unpacked bit-packed indices (same length as
    // `RLE_DECODER_INDEX_BUFFER_SIZE`).
    let mut index_buf = [0i32; 1024];
    while values_read < max_values {
        let remaining = max_values - values_read;
        if self.rle_left > 0 {
            let num_values = cmp::min(remaining, self.rle_left as usize);
            // Surface inconsistent decoder state as an error, mirroring the handling of
            // a missing bit reader below, instead of panicking inside library code.
            let raw = self
                .current_value
                .ok_or_else(|| general_err!("current_value should be set"))?;
            f(RleDecodedBatch::Rle {
                index: raw as i32,
                count: num_values,
            });
            // `num_values <= self.rle_left`, so the cast back to `u32` cannot truncate.
            self.rle_left -= num_values as u32;
            values_read += num_values;
        } else if self.bit_packed_left > 0 {
            let to_read = remaining
                .min(self.bit_packed_left as usize)
                .min(index_buf.len());
            let bit_reader = self
                .bit_reader
                .as_mut()
                .ok_or_else(|| general_err!("bit_reader should be set"))?;

            let num_values =
                bit_reader.get_batch::<i32>(&mut index_buf[..to_read], self.bit_width as usize);
            if num_values == 0 {
                // The reader produced nothing although the run claimed more values:
                // treat the run as finished and move on to the next one.
                self.bit_packed_left = 0;
                continue;
            }
            f(RleDecodedBatch::BitPacked(&index_buf[..num_values]));
            // `num_values <= to_read <= self.bit_packed_left`, so this cannot underflow.
            self.bit_packed_left -= num_values as u32;
            values_read += num_values;
        } else if !self.reload()? {
            // No further run could be loaded.
            break;
        }
    }
    Ok(values_read)
}

#[inline(never)]
pub fn skip(&mut self, num_values: usize) -> Result<usize> {
let mut values_skipped = 0;
Expand Down Expand Up @@ -458,6 +513,13 @@ impl RleDecoder {
{
assert!(buffer.len() >= max_values);

if dict.is_empty() {
return Ok(0);
}
// Clamp index to valid range to prevent UB on corrupt data.
// This is branchless (cmp+csel on ARM) and avoids bounds checks in the hot loop.
let max_idx = dict.len() - 1;

let mut values_read = 0;
while values_read < max_values {
let index_buf = self.index_buf.get_or_insert_with(|| Box::new([0; 1024]));
Expand Down Expand Up @@ -497,7 +559,9 @@ impl RleDecoder {
buffer[values_read..values_read + num_values]
.iter_mut()
.zip(index_buf[..num_values].iter())
.for_each(|(b, i)| b.clone_from(&dict[*i as usize]));
.for_each(|(b, i)| unsafe {
b.clone_from(dict.get_unchecked((*i as usize).min(max_idx)));
});
self.bit_packed_left -= num_values as u32;
values_read += num_values;
if num_values < to_read {
Expand Down
Loading