1
0
Fork 0
mirror of https://codeberg.org/icewind/bitbuffer.git synced 2026-06-03 08:34:07 +02:00

cleanup serde

This commit is contained in:
Robin Appelman 2021-07-25 14:25:56 +02:00
commit 9c3d98dc62
2 changed files with 83 additions and 70 deletions

View file

@ -815,7 +815,21 @@ impl<E: Endianness> Debug for BitReadBuffer<'_, E> {
impl<'a, E: Endianness> PartialEq for BitReadBuffer<'a, E> {
fn eq(&self, other: &Self) -> bool {
self.bit_len == other.bit_len && self.slice == other.slice
if self.bit_len != other.bit_len {
return false;
}
if self.bit_len % 8 == 0 {
self.slice == other.slice
} else {
let bytes = self.bit_len / 8;
let bits_left = self.bit_len % 8;
if self.slice[0..bytes] != other.slice[0..bytes] {
return false;
}
let rest_self = self.read_int::<u8>(bytes * 8, bits_left).unwrap();
let rest_other = other.read_int::<u8>(bytes * 8, bits_left).unwrap();
rest_self == rest_other
}
}
}
@ -840,3 +854,61 @@ fn contains_zero_byte_non_top(x: usize) -> bool {
x.wrapping_sub(LO_USIZE) & !x & HI_USIZE != 0
}
#[cfg(feature = "serde")]
use serde::{de, ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature = "serde")]
impl<'a, E: Endianness> Serialize for BitReadBuffer<'a, E> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut data = self.read_bytes(0, self.bit_len() / 8).unwrap().to_vec();
let bits_left = self.bit_len() % 8;
if bits_left > 0 {
data.push(self.read_int((self.bit_len() / 8) * 8, bits_left).unwrap());
}
let mut s = serializer.serialize_struct("BitReadBuffer", 3)?;
s.serialize_field("data", &data)?;
s.serialize_field("bit_length", &self.bit_len())?;
s.end()
}
}
#[cfg(feature = "serde")]
impl<'de, E: Endianness> Deserialize<'de> for BitReadBuffer<'static, E> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct BitData {
data: Vec<u8>,
bit_length: usize,
}
let data = BitData::deserialize(deserializer)?;
let mut buffer = BitReadBuffer::new_owned(data.data, E::endianness());
buffer
.truncate(data.bit_length)
.map_err(de::Error::custom)?;
Ok(buffer)
}
}
#[cfg(feature = "serde")]
#[test]
fn test_serde_roundtrip() {
use crate::LittleEndian;
let mut buffer = BitReadBuffer::new_owned(vec![55; 8], LittleEndian);
buffer.truncate(61).unwrap();
let json = serde_json::to_string(&buffer).unwrap();
let result: BitReadBuffer<LittleEndian> = serde_json::from_str(&json).unwrap();
assert_eq!(result, buffer);
}

View file

@ -745,11 +745,7 @@ impl<'a, E: Endianness> From<&'a [u8]> for BitReadStream<'a, E> {
}
#[cfg(feature = "serde")]
use serde::{
de::{self, MapAccess, SeqAccess, Visitor},
ser::SerializeStruct,
Deserialize, Deserializer, Serialize, Serializer,
};
use serde::{de, ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
#[cfg(feature = "serde")]
impl<'a, E: Endianness> Serialize for BitReadStream<'a, E> {
@ -777,70 +773,17 @@ impl<'de, E: Endianness> Deserialize<'de> for BitReadStream<'static, E> {
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "snake_case")]
enum Field {
Data,
BitLength,
struct BitData {
data: Vec<u8>,
bit_length: usize,
}
use std::marker::PhantomData;
struct ReadStreamVisitor<E>(PhantomData<E>);
impl<'de, E: Endianness> Visitor<'de> for ReadStreamVisitor<E> {
type Value = BitReadStream<'static, E>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("struct BitReadStream")
}
fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
where
V: SeqAccess<'de>,
{
let data = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(0, &self))?;
let bit_length = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &self))?;
let mut buffer = BitReadBuffer::new_owned(data, E::endianness());
buffer.truncate(bit_length).map_err(de::Error::custom)?;
Ok(BitReadStream::new(buffer))
}
fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
where
V: MapAccess<'de>,
{
let mut data = None;
let mut bit_length = None;
while let Some(key) = map.next_key()? {
match key {
Field::Data => {
if data.is_some() {
return Err(de::Error::duplicate_field("secs"));
}
data = Some(map.next_value()?);
}
Field::BitLength => {
if bit_length.is_some() {
return Err(de::Error::duplicate_field("nanos"));
}
bit_length = Some(map.next_value()?);
}
}
}
let data = data.ok_or_else(|| de::Error::missing_field("data"))?;
let bit_length =
bit_length.ok_or_else(|| de::Error::missing_field("bit_length"))?;
let mut buffer = BitReadBuffer::new_owned(data, E::endianness());
buffer.truncate(bit_length).map_err(de::Error::custom)?;
Ok(BitReadStream::new(buffer))
}
}
const FIELDS: &'static [&'static str] = &["data", "bit_length"];
deserializer.deserialize_struct("BitReadStream", FIELDS, ReadStreamVisitor(PhantomData))
let data = BitData::deserialize(deserializer)?;
let mut buffer = BitReadBuffer::new_owned(data.data, E::endianness());
buffer
.truncate(data.bit_length)
.map_err(de::Error::custom)?;
Ok(BitReadStream::new(buffer))
}
}
@ -856,8 +799,6 @@ fn test_serde_roundtrip() {
let json = serde_json::to_string(&stream).unwrap();
dbg!(&json);
let result: BitReadStream<LittleEndian> = serde_json::from_str(&json).unwrap();
assert_eq!(result, stream);