1
0
Fork 0
mirror of https://codeberg.org/icewind/bitbuffer.git synced 2026-06-03 16:44:06 +02:00

remove extra bounds check in stream

this is done by
 - making the buffer cloneable
 - allow getting a clone of the buffer with a shorter length
 - use the shorter buffer when reading a sub-stream

that way the bounds checks in the buffer are enough to bounds check sub-streams
This commit is contained in:
Robin Appelman 2019-03-03 14:06:17 +01:00
commit 383376f5f0
2 changed files with 61 additions and 25 deletions

View file

@ -4,6 +4,7 @@ use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem::size_of; use std::mem::size_of;
use std::ops::BitOrAssign; use std::ops::BitOrAssign;
use std::rc::Rc;
use num_traits::{Float, PrimInt}; use num_traits::{Float, PrimInt};
@ -37,7 +38,7 @@ pub struct BitBuffer<E>
where where
E: Endianness, E: Endianness,
{ {
bytes: Vec<u8>, bytes: Rc<Vec<u8>>,
bit_len: usize, bit_len: usize,
byte_len: usize, byte_len: usize,
endianness: PhantomData<E>, endianness: PhantomData<E>,
@ -63,7 +64,7 @@ where
pub fn new(bytes: Vec<u8>, _endianness: E) -> Self { pub fn new(bytes: Vec<u8>, _endianness: E) -> Self {
let byte_len = bytes.len(); let byte_len = bytes.len();
BitBuffer { BitBuffer {
bytes, bytes: Rc::new(bytes),
byte_len, byte_len,
bit_len: byte_len * 8, bit_len: byte_len * 8,
endianness: PhantomData, endianness: PhantomData,
@ -439,13 +440,54 @@ where
Ok(T::from_f64_unchecked(f64::from_bits(int))) Ok(T::from_f64_unchecked(f64::from_bits(int)))
} }
} }
/// Get a clone of the buffer with a shorter length
///
/// # Errors
///
/// - [`ReadError::NotEnoughData`]: if the requested length is higher than the buffer length
///
/// # Examples
///
/// ```
/// # use bitstream_reader::{BitBuffer, LittleEndian, Result};
/// #
/// # fn main() -> Result<()> {
/// # let bytes = vec![
/// # 0b1011_0101, 0b0110_1010, 0b1010_1100, 0b1001_1001,
/// # 0b1001_1001, 0b1001_1001, 0b1001_1001, 0b1110_0111
/// # ];
/// # let buffer = BitBuffer::new(bytes, LittleEndian);
/// let sub = buffer.get_sub_buffer(16)?;
/// let result: u8 = sub.read_int(0, 6)?;
/// #
/// # Ok(())
/// # }
/// ```
///
/// [`ReadError::NotEnoughData`]: enum.ReadError.html#variant.NotEnoughData
pub fn get_sub_buffer(&self, bit_len: usize) -> Result<Self> {
if bit_len > self.bit_len {
return Err(ReadError::NotEnoughData {
requested: bit_len,
bits_left: self.bit_len,
});
}
Ok(BitBuffer {
bytes: Rc::clone(&self.bytes),
byte_len: bit_len / 8,
bit_len,
endianness: PhantomData,
})
}
} }
impl<E: Endianness> From<Vec<u8>> for BitBuffer<E> { impl<E: Endianness> From<Vec<u8>> for BitBuffer<E> {
fn from(bytes: Vec<u8>) -> Self { fn from(bytes: Vec<u8>) -> Self {
let byte_len = bytes.len(); let byte_len = bytes.len();
BitBuffer { BitBuffer {
bytes, bytes: Rc::new(bytes),
byte_len, byte_len,
bit_len: byte_len * 8, bit_len: byte_len * 8,
endianness: PhantomData, endianness: PhantomData,
@ -453,6 +495,17 @@ impl<E: Endianness> From<Vec<u8>> for BitBuffer<E> {
} }
} }
impl<E: Endianness> Clone for BitBuffer<E> {
fn clone(&self) -> Self {
BitBuffer {
bytes: Rc::clone(&self.bytes),
byte_len: self.byte_len,
bit_len: self.bit_len,
endianness: PhantomData,
}
}
}
impl<E: Endianness> Debug for BitBuffer<E> { impl<E: Endianness> Debug for BitBuffer<E> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( write!(

View file

@ -1,6 +1,5 @@
use std::mem::size_of; use std::mem::size_of;
use std::ops::BitOrAssign; use std::ops::BitOrAssign;
use std::rc::Rc;
use num_traits::{Float, PrimInt}; use num_traits::{Float, PrimInt};
@ -31,7 +30,7 @@ pub struct BitStream<E>
where where
E: Endianness, E: Endianness,
{ {
buffer: Rc<BitBuffer<E>>, buffer: BitBuffer<E>,
start_pos: usize, start_pos: usize,
pos: usize, pos: usize,
bit_len: usize, bit_len: usize,
@ -62,18 +61,7 @@ where
start_pos: 0, start_pos: 0,
pos: 0, pos: 0,
bit_len: buffer.bit_len(), bit_len: buffer.bit_len(),
buffer: Rc::new(buffer), buffer,
}
}
fn verify_bits_left(&self, count: usize) -> Result<()> {
if self.bits_left() < count {
Err(ReadError::NotEnoughData {
bits_left: self.bits_left(),
requested: count,
})
} else {
Ok(())
} }
} }
@ -105,7 +93,6 @@ where
/// ///
/// [`ReadError::NotEnoughData`]: enum.ReadError.html#variant.NotEnoughData /// [`ReadError::NotEnoughData`]: enum.ReadError.html#variant.NotEnoughData
pub fn read_bool(&mut self) -> Result<bool> { pub fn read_bool(&mut self) -> Result<bool> {
self.verify_bits_left(1)?;
let result = self.buffer.read_bool(self.pos); let result = self.buffer.read_bool(self.pos);
if result.is_ok() { if result.is_ok() {
self.pos += 1; self.pos += 1;
@ -146,7 +133,6 @@ where
where where
T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt, T: PrimInt + BitOrAssign + IsSigned + UncheckedPrimitiveInt,
{ {
self.verify_bits_left(count)?;
let result = self.buffer.read_int(self.pos, count); let result = self.buffer.read_int(self.pos, count);
if result.is_ok() { if result.is_ok() {
self.pos += count; self.pos += count;
@ -185,7 +171,6 @@ where
T: Float + UncheckedPrimitiveFloat, T: Float + UncheckedPrimitiveFloat,
{ {
let count = size_of::<T>() * 8; let count = size_of::<T>() * 8;
self.verify_bits_left(count)?;
let result = self.buffer.read_float(self.pos); let result = self.buffer.read_float(self.pos);
if result.is_ok() { if result.is_ok() {
self.pos += count; self.pos += count;
@ -221,7 +206,6 @@ where
/// [`ReadError::NotEnoughData`]: enum.ReadError.html#variant.NotEnoughData /// [`ReadError::NotEnoughData`]: enum.ReadError.html#variant.NotEnoughData
pub fn read_bytes(&mut self, byte_count: usize) -> Result<Vec<u8>> { pub fn read_bytes(&mut self, byte_count: usize) -> Result<Vec<u8>> {
let count = byte_count * 8; let count = byte_count * 8;
self.verify_bits_left(count)?;
let result = self.buffer.read_bytes(self.pos, byte_count); let result = self.buffer.read_bytes(self.pos, byte_count);
if result.is_ok() { if result.is_ok() {
self.pos += count; self.pos += count;
@ -305,6 +289,7 @@ where
/// assert_eq!(bits.bit_len(), 3); /// assert_eq!(bits.bit_len(), 3);
/// assert_eq!(stream.read_int::<u8>(3)?, 0b110); /// assert_eq!(stream.read_int::<u8>(3)?, 0b110);
/// assert_eq!(bits.read_int::<u8>(3)?, 0b101); /// assert_eq!(bits.read_int::<u8>(3)?, 0b101);
/// assert_eq!(true, bits.read_int::<u8>(1).is_err());
/// # /// #
/// # Ok(()) /// # Ok(())
/// # } /// # }
@ -312,9 +297,8 @@ where
/// ///
/// [`ReadError::NotEnoughData`]: enum.ReadError.html#variant.NotEnoughData /// [`ReadError::NotEnoughData`]: enum.ReadError.html#variant.NotEnoughData
pub fn read_bits(&mut self, count: usize) -> Result<Self> { pub fn read_bits(&mut self, count: usize) -> Result<Self> {
self.verify_bits_left(count)?;
let result = BitStream { let result = BitStream {
buffer: Rc::clone(&self.buffer), buffer: self.buffer.get_sub_buffer(self.pos + count)?,
start_pos: self.pos, start_pos: self.pos,
pos: self.pos, pos: self.pos,
bit_len: count, bit_len: count,
@ -351,7 +335,6 @@ where
/// ///
/// [`ReadError::NotEnoughData`]: enum.ReadError.html#variant.NotEnoughData /// [`ReadError::NotEnoughData`]: enum.ReadError.html#variant.NotEnoughData
pub fn skip(&mut self, count: usize) -> Result<()> { pub fn skip(&mut self, count: usize) -> Result<()> {
self.verify_bits_left(count)?;
self.pos += count; self.pos += count;
Ok(()) Ok(())
} }
@ -523,7 +506,7 @@ where
impl<E: Endianness> Clone for BitStream<E> { impl<E: Endianness> Clone for BitStream<E> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
BitStream { BitStream {
buffer: Rc::clone(&self.buffer), buffer: self.buffer.clone(),
start_pos: self.pos, start_pos: self.pos,
pos: self.pos, pos: self.pos,
bit_len: self.bit_len, bit_len: self.bit_len,