1
0
Fork 0
mirror of https://codeberg.org/icewind/bitbuffer.git synced 2026-06-03 16:44:06 +02:00

allow deriving BitRead for enums

This commit is contained in:
Robin Appelman 2019-02-28 19:05:23 +01:00
commit 8ad8d07dd5
3 changed files with 130 additions and 14 deletions

View file

@ -27,16 +27,13 @@ extern crate proc_macro;
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};
use syn::{Attribute, Data, DeriveInput, Expr, Field, Fields, GenericParam, Generics, Ident, Lit, Meta, parse_macro_input, parse_quote, LitStr};
use syn::spanned::Spanned;
use syn::{
parse_macro_input, parse_quote, Data, DeriveInput, Field, Fields, GenericParam, Generics,
Ident, Lit, Meta,
};
/// See the [crate documentation](index.html) for details
#[proc_macro_derive(BitRead, attributes(size, size_bits))]
#[proc_macro_derive(BitRead, attributes(size, size_bits, discriminant_bits, discriminant))]
pub fn derive_helper_attr(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let input: DeriveInput = parse_macro_input!(input as DeriveInput);
let name = input.ident;
@ -49,7 +46,7 @@ pub fn derive_helper_attr(input: proc_macro::TokenStream) -> proc_macro::TokenSt
.push(parse_quote!(_E: ::bitstream_reader::Endianness));
let (impl_generics, _, _) = trait_generics.split_for_impl();
let parse = parse(&input.data, &name);
let parse = parse(&input.data, &name, &input.attrs);
let expanded = quote! {
impl #impl_generics ::bitstream_reader::BitRead<_E> for #name #ty_generics #where_clause {
@ -76,7 +73,7 @@ fn add_trait_bounds(mut generics: Generics) -> Generics {
generics
}
fn parse(data: &Data, struct_name: &Ident) -> TokenStream {
fn parse(data: &Data, struct_name: &Ident, attrs: &Vec<Attribute>) -> TokenStream {
match *data {
Data::Struct(ref data) => {
match data.fields {
@ -119,13 +116,74 @@ fn parse(data: &Data, struct_name: &Ident) -> TokenStream {
_ => unimplemented!(),
}
}
Data::Enum(ref data) => {
let discriminant_bits = match get_attr(attrs, "discriminant_bits") {
Some(bits_lit) => match bits_lit {
Lit::Int(bits) => bits.value(),
_ => panic!("'discriminant_bits' attribute is required to be an integer literal")
},
None => panic!("'discriminant_bits' attribute is required when deriving `BinRead` for enums")
};
let discriminant_read = quote! {
let discriminant:usize = stream.read_int(#discriminant_bits as usize)?;
};
let mut last_discriminant = -1;
let mut discriminants = Vec::with_capacity(data.variants.len());
for variant in &data.variants {
let discriminant = variant.discriminant.clone()
.map(|(_, expr)| match expr {
Expr::Lit(expr_lit) => expr_lit.lit,
_ => panic!("discriminant is required to be an integer literal")
})
.or_else(|| get_attr(&variant.attrs, "discriminant"))
.map(|lit| match lit {
Lit::Int(lit) => lit.value(),
_ => panic!("discriminant is required to be an integer literal")
})
.unwrap_or_else(|| {
(last_discriminant + 1) as u64
}) as usize;
last_discriminant = discriminant as isize;
discriminants.push(discriminant)
}
let match_arms = data.variants.iter().zip(discriminants.iter()).map(|(variant, discriminant)| {
let span = variant.span();
let variant_name = &variant.ident;
let read_fields = match &variant.fields {
Fields::Unit => quote_spanned! {span=>
#struct_name::#variant_name
},
Fields::Unnamed(f) => quote_spanned! {span=>
#struct_name::#variant_name(stream.read()?)
},
_ => unimplemented!()
};
quote_spanned! {span=>
#discriminant => #read_fields,
}
});
let span = data.enum_token.span();
let enum_name = Lit::Str(LitStr::new(&struct_name.to_string(), struct_name.span()));
quote_spanned! {span=>
#discriminant_read
Ok(match discriminant {
#(#match_arms)*
_ => {
return Err(::bitstream_reader::ReadError::UnmatchedDiscriminant{discriminant, enum_name: #enum_name.to_string()})
}
})
}
}
_ => unimplemented!(),
}
}
fn get_field_size(field: &Field) -> Option<TokenStream> {
let span = field.span();
get_field_attr(field, "size")
get_attr(&field.attrs, "size")
.map(|size_lit| match size_lit {
Lit::Int(size) => {
quote_spanned! {span=>
@ -141,7 +199,7 @@ fn get_field_size(field: &Field) -> Option<TokenStream> {
_ => panic!("Unsupported value for size attribute"),
})
.or_else(|| {
get_field_attr(field, "size_bits").map(|size_bits_lit| {
get_attr(&field.attrs, "size_bits").map(|size_bits_lit| {
quote_spanned! {span=>
stream.read_int::<usize>(#size_bits_lit)?
}
@ -149,8 +207,8 @@ fn get_field_size(field: &Field) -> Option<TokenStream> {
})
}
fn get_field_attr(field: &Field, name: &str) -> Option<Lit> {
for attr in field.attrs.iter() {
fn get_attr(attrs: &Vec<Attribute>, name: &str) -> Option<Lit> {
for attr in attrs.iter() {
let meta = attr.parse_meta().unwrap();
match meta {
Meta::NameValue(ref name_value) if name_value.ident == name => {
@ -160,4 +218,4 @@ fn get_field_attr(field: &Field, name: &str) -> Option<Lit> {
}
}
None
}
}

View file

@ -1,4 +1,4 @@
use bitstream_reader::{BitBuffer, BitStream, LittleEndian};
use bitstream_reader::{BigEndian, BitBuffer, BitStream, LittleEndian, ReadError};
use bitstream_reader_derive::BitRead;
#[derive(BitRead, PartialEq, Debug)]
@ -55,3 +55,52 @@ fn test_read_struct() {
stream.read().unwrap()
);
}
#[derive(BitRead, PartialEq, Debug)]
#[discriminant_bits = 2]
enum TestBareEnum {
Foo,
Bar,
Asd = 3,
}
#[test]
fn test_read_bare_enum() {
let bytes = vec![
0b1100_0110, 0b1000_0100, 0b1000_0100, 0b1000_0100,
0b1000_0100, 0b1000_0100, 0b1000_0100, 0b1000_0100,
];
let buffer = BitBuffer::new(bytes, BigEndian);
let mut stream = BitStream::from(buffer);
assert_eq!(TestBareEnum::Asd, stream.read().unwrap());
assert_eq!(TestBareEnum::Foo, stream.read().unwrap());
assert_eq!(TestBareEnum::Bar, stream.read().unwrap());
assert_eq!(true, stream.read::<TestBareEnum>().is_err());
}
#[derive(BitRead, PartialEq, Debug)]
#[discriminant_bits = 2]
enum TestUnnamedFieldEnum {
Foo(i8),
Bar(bool),
#[discriminant = 3]
Asd(u8),
}
#[test]
fn test_read_unnamed_field_enum() {
let bytes = vec![
0b1100_0110, 0b1000_0100, 0b1000_0100, 0b1000_0100,
0b1000_0100, 0b1000_0100, 0b1000_0100, 0b1000_0100,
];
let buffer = BitBuffer::new(bytes, BigEndian);
let mut stream = BitStream::from(buffer);
assert_eq!(TestUnnamedFieldEnum::Asd(0b_00_0110_10), stream.read().unwrap());
assert_eq!(10, stream.pos());
stream.set_pos(2).unwrap();
assert_eq!(TestUnnamedFieldEnum::Foo(0b11_0_1000), stream.read().unwrap());
assert_eq!(12, stream.pos());
stream.set_pos(4).unwrap();
assert_eq!(TestUnnamedFieldEnum::Bar(true), stream.read().unwrap());
assert_eq!(7, stream.pos());
}

View file

@ -85,6 +85,13 @@ pub enum ReadError {
/// the number of bits in the buffer
size: usize,
},
/// Unmatched discriminant found while trying to read an enum
UnmatchedDiscriminant {
/// The read discriminant
discriminant: usize,
/// The name of the enum that is trying to be read
enum_name: String,
},
/// The read slice of bytes are not valid utf8
Utf8Error(FromUtf8Error),
}
@ -98,6 +105,8 @@ impl Display for ReadError {
write!(f, "Not enough data in the buffer to read all requested bits, requested to read {} bits while only {} bits are left", requested, bits_left),
ReadError::IndexOutOfBounds { pos, size } =>
write!(f, "The requested position is outside the bounds of the stream, requested position {} while the stream or buffer is only {} bits long", pos, size),
ReadError::UnmatchedDiscriminant { discriminant, enum_name } =>
write!(f, "Unmatched discriminant '{}' found while trying to read enum '{}'", discriminant, enum_name),
ReadError::Utf8Error(err) => err.fmt(f)
}
}