1
0
Fork 0
mirror of https://codeberg.org/icewind/bitbuffer.git synced 2026-06-03 08:34:07 +02:00

remove the need for padding by special casing reading near the end of the data

This commit is contained in:
Robin Appelman 2020-12-05 21:41:32 +01:00
commit a276e5a457
5 changed files with 167 additions and 118 deletions

View file

@ -205,7 +205,7 @@ fn derive_bitread_trait(
};
let extra_param_call = if extra_param.is_some() {
Some(quote!(input_size))
Some(quote!(input_size,))
} else {
None
};
@ -225,9 +225,9 @@ fn derive_bitread_trait(
// if the read has a predicable size, we can do the bounds check in one go
match <Self as #trait_def>::#size_method_name(#extra_param_call) {
Some(size) => {
stream.check_read(size)?;
let end = stream.check_read(size)?;
unsafe {
<Self as #trait_def>::read_unchecked(stream, #extra_param_call)
<Self as #trait_def>::read_unchecked(stream, #extra_param_call end)
}
},
None => {
@ -236,7 +236,7 @@ fn derive_bitread_trait(
}
}
unsafe fn read_unchecked(stream: &mut ::bitbuffer::BitReadStream<#endianness_ident>#extra_param) -> ::bitbuffer::Result<Self> {
unsafe fn read_unchecked(stream: &mut ::bitbuffer::BitReadStream<#endianness_ident>#extra_param, end: bool) -> ::bitbuffer::Result<Self> {
#parsed_unchecked
}
@ -267,13 +267,13 @@ fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool)
quote_spanned! { span =>
{
let _size: usize = #size;
stream.read_sized_unchecked::<#field_type>(_size)?
stream.read_sized_unchecked::<#field_type>(_size, end)?
}
}
}
None => {
quote_spanned! { span =>
stream.read_unchecked::<#field_type>()?
stream.read_unchecked::<#field_type>(end)?
}
}
}

View file

@ -98,7 +98,7 @@ pub trait BitRead<E: Endianness>: Sized {
/// any other validations (e.g. checking for valid utf8) still needs to be done
#[doc(hidden)]
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>) -> Result<Self> {
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, _end: bool) -> Result<Self> {
Self::read(stream)
}
@ -130,8 +130,8 @@ macro_rules! impl_read_int {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>) -> Result<$type> {
Ok(stream.read_int_unchecked::<$type>(size_of::<$type>() * 8))
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, end: bool) -> Result<$type> {
Ok(stream.read_int_unchecked::<$type>(size_of::<$type>() * 8, end))
}
#[inline]
@ -151,9 +151,12 @@ macro_rules! impl_read_int_nonzero {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<LittleEndian>) -> Result<Self> {
unsafe fn read_unchecked(
stream: &mut BitReadStream<LittleEndian>,
end: bool,
) -> Result<Self> {
Ok(<$type>::new(
stream.read_int_unchecked(size_of::<$type>() * 8),
stream.read_int_unchecked(size_of::<$type>() * 8, end),
))
}
@ -170,9 +173,12 @@ macro_rules! impl_read_int_nonzero {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<BigEndian>) -> Result<Self> {
unsafe fn read_unchecked(
stream: &mut BitReadStream<BigEndian>,
end: bool,
) -> Result<Self> {
Ok(<$type>::new(
stream.read_int_unchecked(size_of::<$type>() * 8),
stream.read_int_unchecked(size_of::<$type>() * 8, end),
))
}
@ -208,8 +214,8 @@ impl<E: Endianness> BitRead<E> for f32 {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>) -> Result<f32> {
Ok(stream.read_float_unchecked::<f32>())
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, end: bool) -> Result<f32> {
Ok(stream.read_float_unchecked::<f32>(end))
}
#[inline]
@ -225,8 +231,8 @@ impl<E: Endianness> BitRead<E> for f64 {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>) -> Result<f64> {
Ok(stream.read_float_unchecked::<f64>())
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, end: bool) -> Result<f64> {
Ok(stream.read_float_unchecked::<f64>(end))
}
#[inline]
@ -242,7 +248,7 @@ impl<E: Endianness> BitRead<E> for bool {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>) -> Result<bool> {
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, _end: bool) -> Result<bool> {
Ok(stream.read_bool_unchecked())
}
@ -266,8 +272,8 @@ impl<E: Endianness, T: BitRead<E>> BitRead<E> for Rc<T> {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>) -> Result<Self> {
Ok(Rc::new(T::read_unchecked(stream)?))
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, end: bool) -> Result<Self> {
Ok(Rc::new(T::read_unchecked(stream, end)?))
}
#[inline]
@ -283,8 +289,8 @@ impl<E: Endianness, T: BitRead<E>> BitRead<E> for Arc<T> {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>) -> Result<Self> {
Ok(Arc::new(T::read_unchecked(stream)?))
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, end: bool) -> Result<Self> {
Ok(Arc::new(T::read_unchecked(stream, end)?))
}
#[inline]
@ -300,8 +306,8 @@ impl<E: Endianness, T: BitRead<E>> BitRead<E> for Box<T> {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>) -> Result<Self> {
Ok(Box::new(T::read_unchecked(stream)?))
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, end: bool) -> Result<Self> {
Ok(Box::new(T::read_unchecked(stream, end)?))
}
#[inline]
@ -319,8 +325,8 @@ macro_rules! impl_read_tuple {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>) -> Result<Self> {
Ok(($(<$type>::read_unchecked(stream)?),*))
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, end: bool) -> Result<Self> {
Ok(($(<$type>::read_unchecked(stream, end)?),*))
}
#[inline]
@ -405,7 +411,11 @@ pub trait BitReadSized<E: Endianness>: Sized {
#[doc(hidden)]
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, size: usize) -> Result<Self> {
unsafe fn read_unchecked(
stream: &mut BitReadStream<E>,
size: usize,
_end: bool,
) -> Result<Self> {
Self::read(stream, size)
}
@ -437,8 +447,12 @@ macro_rules! impl_read_int_sized {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, size: usize) -> Result<$type> {
Ok(stream.read_int_unchecked::<$type>(size))
unsafe fn read_unchecked(
stream: &mut BitReadStream<E>,
size: usize,
end: bool,
) -> Result<$type> {
Ok(stream.read_int_unchecked::<$type>(size, end))
}
#[inline]
@ -516,10 +530,14 @@ impl<E: Endianness, T: BitRead<E>> BitReadSized<E> for Vec<T> {
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, size: usize) -> Result<Self> {
unsafe fn read_unchecked(
stream: &mut BitReadStream<E>,
size: usize,
end: bool,
) -> Result<Self> {
let mut vec = Vec::with_capacity(min(size, 128));
for _ in 0..size {
vec.push(stream.read_unchecked()?)
vec.push(stream.read_unchecked(end)?)
}
Ok(vec)
}
@ -552,11 +570,15 @@ impl<E: Endianness, K: BitRead<E> + Eq + Hash, T: BitRead<E>> BitReadSized<E> fo
}
#[inline]
unsafe fn read_unchecked(stream: &mut BitReadStream<E>, size: usize) -> Result<Self> {
unsafe fn read_unchecked(
stream: &mut BitReadStream<E>,
size: usize,
end: bool,
) -> Result<Self> {
let mut map = HashMap::with_capacity(min(size, 128));
for _ in 0..size {
let key = stream.read_unchecked()?;
let value = stream.read_unchecked()?;
let key = stream.read_unchecked(end)?;
let value = stream.read_unchecked(end)?;
map.insert(key, value);
}
Ok(map)

View file

@ -14,6 +14,7 @@ use crate::{BitError, Result};
use std::convert::TryInto;
const USIZE_SIZE: usize = size_of::<usize>();
const USIZE_BIT_SIZE: usize = USIZE_SIZE * 8;
/// Buffer that allows reading integers of arbitrary bit length and non byte-aligned integers
///
@ -60,11 +61,9 @@ where
/// ];
/// let buffer = BitReadBuffer::new(bytes, LittleEndian);
/// ```
pub fn new(mut bytes: Vec<u8>, _endianness: E) -> Self {
pub fn new(bytes: Vec<u8>, _endianness: E) -> Self {
let byte_len = bytes.len();
// pad with usize worth of bytes to ensure we can always read a full usize
bytes.extend_from_slice(&0usize.to_le_bytes());
BitReadBuffer {
bytes: Rc::new(bytes),
bit_len: byte_len * 8,
@ -103,29 +102,37 @@ where
self.bytes.len()
}
unsafe fn read_usize_bytes(&self, byte_index: usize) -> [u8; USIZE_SIZE] {
debug_assert!(byte_index + USIZE_SIZE <= self.bytes.len());
// this is safe because all calling paths check that byte_index is less than the unpadded
// length (because they check based on bit_len), so with padding byte_index + USIZE_SIZE is
// always within bounds
self.bytes
.get_unchecked(byte_index..byte_index + USIZE_SIZE)
.try_into()
.unwrap()
unsafe fn read_usize_bytes(&self, byte_index: usize, end: bool) -> [u8; USIZE_SIZE] {
if end {
let mut bytes = [0; USIZE_SIZE];
let count = min(USIZE_SIZE, self.bytes.len() - byte_index);
bytes[0..count]
.copy_from_slice(self.bytes.get_unchecked(byte_index..byte_index + count));
bytes
} else {
debug_assert!(byte_index + USIZE_SIZE <= self.bytes.len());
// this is safe because all calling paths check that byte_index is less than the unpadded
// length (because they check based on bit_len), so with padding byte_index + USIZE_SIZE is
// always within bounds
self.bytes
.get_unchecked(byte_index..byte_index + USIZE_SIZE)
.try_into()
.unwrap()
}
}
/// note that only the bottom USIZE - 1 bytes are usable
unsafe fn read_shifted_usize(&self, byte_index: usize, shift: usize) -> usize {
let raw_bytes: [u8; USIZE_SIZE] = self.read_usize_bytes(byte_index);
unsafe fn read_shifted_usize(&self, byte_index: usize, shift: usize, end: bool) -> usize {
let raw_bytes: [u8; USIZE_SIZE] = self.read_usize_bytes(byte_index, end);
let raw_usize: usize = usize::from_le_bytes(raw_bytes);
raw_usize >> shift
}
unsafe fn read_usize(&self, position: usize, count: usize) -> usize {
unsafe fn read_usize(&self, position: usize, count: usize, end: bool) -> usize {
let byte_index = position / 8;
let bit_offset = position & 7;
let bytes: [u8; USIZE_SIZE] = self.read_usize_bytes(byte_index);
let bytes: [u8; USIZE_SIZE] = self.read_usize_bytes(byte_index, end);
let container = if E::is_le() {
usize::from_le_bytes(bytes)
@ -235,26 +242,31 @@ where
});
}
if position + count > self.bit_len() {
return if position > self.bit_len() {
Err(BitError::IndexOutOfBounds {
pos: position,
size: self.bit_len(),
})
} else {
Err(BitError::NotEnoughData {
requested: count,
bits_left: self.bit_len() - position,
})
};
}
let end = if position + count + USIZE_BIT_SIZE > self.bit_len() {
if position + count > self.bit_len() {
return if position > self.bit_len() {
Err(BitError::IndexOutOfBounds {
pos: position,
size: self.bit_len(),
})
} else {
Err(BitError::NotEnoughData {
requested: count,
bits_left: self.bit_len() - position,
})
};
}
true
} else {
false
};
Ok(unsafe { self.read_int_unchecked(position, count) })
Ok(unsafe { self.read_int_unchecked(position, count, end) })
}
#[doc(hidden)]
#[inline]
pub unsafe fn read_int_unchecked<T>(&self, position: usize, count: usize) -> T
pub unsafe fn read_int_unchecked<T>(&self, position: usize, count: usize, end: bool) -> T
where
T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt + BitXor,
{
@ -265,9 +277,9 @@ where
let fit_usize = count + bit_offset < usize_bit_size;
let value = if fit_usize {
self.read_fit_usize(position, count)
self.read_fit_usize(position, count, end)
} else {
self.read_no_fit_usize(position, count)
self.read_no_fit_usize(position, count, end)
};
if count == type_bit_size {
@ -278,15 +290,15 @@ where
}
#[inline]
unsafe fn read_fit_usize<T>(&self, position: usize, count: usize) -> T
unsafe fn read_fit_usize<T>(&self, position: usize, count: usize, end: bool) -> T
where
T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt,
{
let raw = self.read_usize(position, count);
let raw = self.read_usize(position, count, end);
T::from_unchecked(raw)
}
unsafe fn read_no_fit_usize<T>(&self, position: usize, count: usize) -> T
unsafe fn read_no_fit_usize<T>(&self, position: usize, count: usize, end: bool) -> T
where
T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt,
{
@ -298,7 +310,7 @@ where
while left_to_read > 0 {
let bits_left = self.bit_len() - read_pos;
let read = min(min(left_to_read, max_read), bits_left);
let data = T::from_unchecked(self.read_usize(read_pos, read));
let data = T::from_unchecked(self.read_usize(read_pos, read, end));
if E::is_le() {
acc |= data << bit_offset;
} else {
@ -392,7 +404,9 @@ where
let mut byte_left = byte_count;
let mut read_pos = position / 8;
while byte_left > USIZE_SIZE - 1 {
let bytes = self.read_shifted_usize(read_pos, shift).to_le_bytes();
let bytes = self
.read_shifted_usize(read_pos, shift, false)
.to_le_bytes();
let read_bytes = USIZE_SIZE - 1;
let usable_bytes = &bytes[0..read_bytes];
data.extend_from_slice(usable_bytes);
@ -401,7 +415,7 @@ where
byte_left -= read_bytes;
}
let bytes = self.read_shifted_usize(read_pos, shift).to_le_bytes();
let bytes = self.read_shifted_usize(read_pos, shift, true).to_le_bytes();
let usable_bytes = &bytes[0..byte_left];
data.extend_from_slice(usable_bytes);
@ -462,7 +476,7 @@ where
fn find_null_byte(&self, byte_index: usize) -> usize {
memchr::memchr(0, &self.bytes[byte_index..])
.map(|index| index + byte_index)
.unwrap() // due to padding we always have 0 bytes at the end
.unwrap_or(self.bytes.len()) // due to padding we always have 0 bytes at the end
}
#[inline]
@ -481,7 +495,7 @@ where
//
// This is safe because the final usize is filled with 0's, thus triggering the exit clause
// before reading any out of bounds
let shifted = unsafe { self.read_shifted_usize(byte_index, shift) };
let shifted = unsafe { self.read_shifted_usize(byte_index, shift, true) };
let has_null = contains_zero_byte_non_top(shifted);
let bytes: [u8; USIZE_SIZE] = shifted.to_le_bytes();
@ -533,38 +547,43 @@ where
T: Float + UncheckedPrimitiveFloat,
{
let type_bit_size = size_of::<T>() * 8;
if position + type_bit_size > self.bit_len() {
if position > self.bit_len() {
return Err(BitError::IndexOutOfBounds {
pos: position,
size: self.bit_len(),
});
} else {
return Err(BitError::NotEnoughData {
requested: size_of::<T>() * 8,
bits_left: self.bit_len() - position,
});
let end = if position + type_bit_size + USIZE_BIT_SIZE > self.bit_len() {
if position + type_bit_size > self.bit_len() {
if position > self.bit_len() {
return Err(BitError::IndexOutOfBounds {
pos: position,
size: self.bit_len(),
});
} else {
return Err(BitError::NotEnoughData {
requested: size_of::<T>() * 8,
bits_left: self.bit_len() - position,
});
}
}
}
true
} else {
false
};
Ok(unsafe { self.read_float_unchecked(position) })
Ok(unsafe { self.read_float_unchecked(position, end) })
}
#[doc(hidden)]
#[inline]
pub unsafe fn read_float_unchecked<T>(&self, position: usize) -> T
pub unsafe fn read_float_unchecked<T>(&self, position: usize, end: bool) -> T
where
T: Float + UncheckedPrimitiveFloat,
{
if size_of::<T>() == 4 {
let int = if size_of::<T>() < USIZE_SIZE {
self.read_fit_usize::<u32>(position, 32)
self.read_fit_usize::<u32>(position, 32, end)
} else {
self.read_no_fit_usize::<u32>(position, 32)
self.read_no_fit_usize::<u32>(position, 32, end)
};
T::from_f32_unchecked(f32::from_bits(int))
} else {
let int = self.read_no_fit_usize::<u64>(position, 64);
let int = self.read_no_fit_usize::<u64>(position, 64, end);
T::from_f64_unchecked(f64::from_bits(int))
}
}

View file

@ -150,11 +150,11 @@ where
#[doc(hidden)]
#[inline]
pub unsafe fn read_int_unchecked<T>(&mut self, count: usize) -> T
pub unsafe fn read_int_unchecked<T>(&mut self, count: usize, end: bool) -> T
where
T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt,
{
let result = self.buffer.read_int_unchecked(self.pos, count);
let result = self.buffer.read_int_unchecked(self.pos, count, end);
self.pos += count;
result
}
@ -200,12 +200,12 @@ where
#[doc(hidden)]
#[inline]
pub unsafe fn read_float_unchecked<T>(&mut self) -> T
pub unsafe fn read_float_unchecked<T>(&mut self, end: bool) -> T
where
T: Float + UncheckedPrimitiveFloat,
{
let count = size_of::<T>() * 8;
let result = self.buffer.read_float_unchecked(self.pos);
let result = self.buffer.read_float_unchecked(self.pos, end);
self.pos += count;
result
}
@ -585,8 +585,8 @@ where
#[doc(hidden)]
#[inline]
pub unsafe fn read_unchecked<T: BitRead<E>>(&mut self) -> Result<T> {
T::read_unchecked(self)
pub unsafe fn read_unchecked<T: BitRead<E>>(&mut self, end: bool) -> Result<T> {
T::read_unchecked(self, end)
}
/// Read a value based on the provided type and size
@ -635,19 +635,27 @@ where
#[doc(hidden)]
#[inline]
pub unsafe fn read_sized_unchecked<T: BitReadSized<E>>(&mut self, size: usize) -> Result<T> {
T::read_unchecked(self, size)
pub unsafe fn read_sized_unchecked<T: BitReadSized<E>>(
&mut self,
size: usize,
end: bool,
) -> Result<T> {
T::read_unchecked(self, size, end)
}
/// Check if we can read a number of bits from the stream
pub fn check_read(&self, count: usize) -> Result<()> {
if self.bits_left() < count {
Err(BitError::NotEnoughData {
requested: count,
bits_left: self.bits_left(),
})
pub fn check_read(&self, count: usize) -> Result<bool> {
if self.bits_left() < count + 64 {
if self.bits_left() < count {
Err(BitError::NotEnoughData {
requested: count,
bits_left: self.bits_left(),
})
} else {
Ok(true)
}
} else {
Ok(())
Ok(false)
}
}
}

View file

@ -300,18 +300,18 @@ fn read_trait_unchecked() {
unsafe {
let buffer = BitReadBuffer::new(BYTES.to_vec(), BigEndian);
let mut stream = BitReadStream::new(buffer);
let a: u8 = stream.read_unchecked().unwrap();
let a: u8 = stream.read_unchecked(true).unwrap();
assert_eq!(0b1011_0101, a);
let b: i8 = stream.read_unchecked().unwrap();
let b: i8 = stream.read_unchecked(true).unwrap();
assert_eq!(0b110_1010, b);
let c: i16 = stream.read_unchecked().unwrap();
let c: i16 = stream.read_unchecked(true).unwrap();
assert_eq!(-0b101_0011_0110_0111, c);
let d: bool = stream.read_unchecked().unwrap();
let d: bool = stream.read_unchecked(true).unwrap();
assert_eq!(true, d);
let e: Option<u8> = stream.read_unchecked().unwrap();
let e: Option<u8> = stream.read_unchecked(true).unwrap();
assert_eq!(None, e);
stream.set_pos(0).unwrap();
let f: Option<u8> = stream.read_unchecked().unwrap();
let f: Option<u8> = stream.read_unchecked(true).unwrap();
assert_eq!(Some(0b011_0101_0), f);
}
}
@ -351,10 +351,10 @@ fn read_sized_trait_unchecked() {
unsafe {
let buffer = BitReadBuffer::new(BYTES.to_vec(), BigEndian);
let mut stream = BitReadStream::new(buffer);
let a: u8 = stream.read_sized_unchecked(4).unwrap();
let a: u8 = stream.read_sized_unchecked(4, true).unwrap();
assert_eq!(0b1011, a);
stream.set_pos(0).unwrap();
let vec: Vec<u16> = stream.read_sized_unchecked(3).unwrap();
let vec: Vec<u16> = stream.read_sized_unchecked(3, true).unwrap();
assert_eq!(
vec![
0b1011_0101_0110_1010,
@ -364,16 +364,16 @@ fn read_sized_trait_unchecked() {
vec
);
stream.set_pos(0).unwrap();
let vec: Vec<u8> = stream.read_sized_unchecked(3).unwrap();
let vec: Vec<u8> = stream.read_sized_unchecked(3, true).unwrap();
assert_eq!(vec![0b1011_0101, 0b0110_1010, 0b1010_1100], vec);
stream.set_pos(0).unwrap();
let result: HashMap<u8, u8> = stream.read_sized_unchecked(2).unwrap();
let result: HashMap<u8, u8> = stream.read_sized_unchecked(2, true).unwrap();
assert_eq!(
hashmap!(0b1011_0101 => 0b0110_1010, 0b1010_1100 => 0b1001_1001),
result
);
stream.set_pos(0).unwrap();
let mut result: BitReadStream<BigEndian> = stream.read_sized_unchecked(4).unwrap();
let mut result: BitReadStream<BigEndian> = stream.read_sized_unchecked(4, true).unwrap();
assert_eq!(0b10u8, result.read_int(2).unwrap());
}
}