1
0
Fork 0
mirror of https://codeberg.org/icewind/bitbuffer.git synced 2026-06-03 16:44:06 +02:00

Add #[align] attribute to structs

This attribute aligns the reader to byte boundary

It can be applied to enums & structs to align before reading any fields or the discriminant

It can also be applied to individual struct fields to align the reader before reading the field

Finally, you can apply it to non-unit enum variants to align the reader after reading the discriminant, but before reading the payload
This commit is contained in:
Nikita Strygin 2023-02-06 15:49:33 +03:00
commit e33aa2f776
2 changed files with 139 additions and 13 deletions

View file

@ -146,12 +146,12 @@ use syn::{
parse_macro_input, parse_quote, parse_str, Attribute, Data, DataStruct, DeriveInput, Expr, parse_macro_input, parse_quote, parse_str, Attribute, Data, DataStruct, DeriveInput, Expr,
Fields, GenericParam, Ident, Lit, LitInt, LitStr, Path, Fields, GenericParam, Ident, Lit, LitInt, LitStr, Path,
}; };
use syn_util::get_attribute_value; use syn_util::{contains_attribute, get_attribute_value};
/// See the [crate documentation](index.html) for details /// See the [crate documentation](index.html) for details
#[proc_macro_derive( #[proc_macro_derive(
BitRead, BitRead,
attributes(size, size_bits, discriminant_bits, discriminant, endianness) attributes(size, size_bits, discriminant_bits, discriminant, endianness, align)
)] )]
pub fn derive_bitread(input: proc_macro::TokenStream) -> proc_macro::TokenStream { pub fn derive_bitread(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
derive_bitread_trait(input, "BitRead".to_owned(), None) derive_bitread_trait(input, "BitRead".to_owned(), None)
@ -161,7 +161,7 @@ pub fn derive_bitread(input: proc_macro::TokenStream) -> proc_macro::TokenStream
/// See the [crate documentation](index.html) for details /// See the [crate documentation](index.html) for details
#[proc_macro_derive( #[proc_macro_derive(
BitReadSized, BitReadSized,
attributes(size, size_bits, discriminant_bits, discriminant, endianness) attributes(size, size_bits, discriminant_bits, discriminant, endianness, align)
)] )]
pub fn derive_bitread_sized(input: proc_macro::TokenStream) -> proc_macro::TokenStream { pub fn derive_bitread_sized(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let extra_param = parse_str::<TokenStream>(", input_size: usize").unwrap(); let extra_param = parse_str::<TokenStream>(", input_size: usize").unwrap();
@ -171,7 +171,7 @@ pub fn derive_bitread_sized(input: proc_macro::TokenStream) -> proc_macro::Token
/// See the [crate documentation](index.html) for details /// See the [crate documentation](index.html) for details
#[proc_macro_derive( #[proc_macro_derive(
BitWrite, BitWrite,
attributes(size, size_bits, discriminant_bits, discriminant, endianness) attributes(size, size_bits, discriminant_bits, discriminant, endianness, align)
)] )]
pub fn derive_bitwrite(input: proc_macro::TokenStream) -> proc_macro::TokenStream { pub fn derive_bitwrite(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
derive_bitwrite_trait(input, "BitWrite".into(), "write".into(), None) derive_bitwrite_trait(input, "BitWrite".into(), "write".into(), None)
@ -181,7 +181,7 @@ pub fn derive_bitwrite(input: proc_macro::TokenStream) -> proc_macro::TokenStrea
/// See the [crate documentation](index.html) for details /// See the [crate documentation](index.html) for details
#[proc_macro_derive( #[proc_macro_derive(
BitWriteSized, BitWriteSized,
attributes(size, size_bits, discriminant_bits, discriminant, endianness) attributes(size, size_bits, discriminant_bits, discriminant, endianness, align)
)] )]
pub fn derive_bitwrite_sized(input: proc_macro::TokenStream) -> proc_macro::TokenStream { pub fn derive_bitwrite_sized(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let extra_param = parse_str::<TokenStream>(", input_size: usize").unwrap(); let extra_param = parse_str::<TokenStream>(", input_size: usize").unwrap();
@ -299,11 +299,14 @@ fn derive_bitread_trait(
fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool) -> TokenStream { fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool) -> TokenStream {
let span = struct_name.span(); let span = struct_name.span();
let align = get_align(attrs);
match data { match data {
Data::Struct(DataStruct { fields, .. }) => { Data::Struct(DataStruct { fields, .. }) => {
let values = fields.iter().map(|f| { let values = fields.iter().map(|f| {
// Get attributes `#[..]` on each field // Get attributes `#[..]` on each field
let size = get_field_size(&f.attrs, f.span()); let size = get_field_size(&f.attrs, f.span());
let align = get_align(attrs);
let field_type = &f.ty; let field_type = &f.ty;
let span = f.span(); let span = f.span();
if unchecked { if unchecked {
@ -311,6 +314,7 @@ fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool)
Some(size) => { Some(size) => {
quote_spanned! { span => quote_spanned! { span =>
{ {
#align;
let _size: usize = #size; let _size: usize = #size;
stream.read_sized_unchecked::<#field_type>(_size, end)? stream.read_sized_unchecked::<#field_type>(_size, end)?
} }
@ -318,15 +322,19 @@ fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool)
} }
None => { None => {
quote_spanned! { span => quote_spanned! { span =>
{
#align;
stream.read_unchecked::<#field_type>(end)? stream.read_unchecked::<#field_type>(end)?
} }
} }
} }
}
} else { } else {
match size { match size {
Some(size) => { Some(size) => {
quote_spanned! { span => quote_spanned! { span =>
{ {
#align;
let _size: usize = #size; let _size: usize = #size;
stream.read_sized::<#field_type>(_size)? stream.read_sized::<#field_type>(_size)?
} }
@ -334,11 +342,14 @@ fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool)
} }
None => { None => {
quote_spanned! { span => quote_spanned! { span =>
{
#align;
stream.read::<#field_type>()? stream.read::<#field_type>()?
} }
} }
} }
} }
}
}); });
match &fields { match &fields {
@ -356,6 +367,8 @@ fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool)
} }
}); });
quote_spanned! { span => quote_spanned! { span =>
#align;
#(#definitions)* #(#definitions)*
Ok(#struct_name { Ok(#struct_name {
@ -364,11 +377,15 @@ fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool)
} }
} }
Fields::Unnamed(_) => quote_spanned! { span => Fields::Unnamed(_) => quote_spanned! { span =>
#align;
Ok(#struct_name( Ok(#struct_name(
#(#values ,)* #(#values ,)*
)) ))
}, },
Fields::Unit => quote_spanned! {span=> Fields::Unit => quote_spanned! { span=>
#align;
Ok(#struct_name) Ok(#struct_name)
}, },
} }
@ -377,7 +394,7 @@ fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool)
let discriminant_bits: u64 = match get_attribute_value(attrs, &["discriminant_bits"]) { let discriminant_bits: u64 = match get_attribute_value(attrs, &["discriminant_bits"]) {
Some(attr) => attr, Some(attr) => attr,
None => { None => {
return quote! {span=> return quote_spanned! { span=>
compile_error!("'discriminant_bits' attribute is required when deriving `BinRead` for enums"); compile_error!("'discriminant_bits' attribute is required when deriving `BinRead` for enums");
} }
} }
@ -388,15 +405,24 @@ fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool)
let span = variant.span(); let span = variant.span();
let variant_name = &variant.ident; let variant_name = &variant.ident;
let read_fields = match &variant.fields { let read_fields = match &variant.fields {
Fields::Unit => quote_spanned! {span=> Fields::Unit => {
if contains_attribute(&variant.attrs, &["align"]) {
return quote_spanned! { span =>
compile_error!("'align' attribute is not allowed on unit variants");
};
}
quote_spanned! { span=>
#struct_name::#variant_name #struct_name::#variant_name
}, }
}
Fields::Unnamed(f) => { Fields::Unnamed(f) => {
let size = get_field_size(&variant.attrs, f.span()); let size = get_field_size(&variant.attrs, f.span());
let align = get_align(&variant.attrs);
match size { match size {
Some(size) => { Some(size) => {
quote_spanned! { span => quote_spanned! { span =>
#struct_name::#variant_name({ #struct_name::#variant_name({
#align;
let _size:usize = #size; let _size:usize = #size;
stream.read_sized(_size)? stream.read_sized(_size)?
}) })
@ -404,11 +430,14 @@ fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool)
} }
None => { None => {
quote_spanned! { span => quote_spanned! { span =>
{
#align;
#struct_name::#variant_name(stream.read()?) #struct_name::#variant_name(stream.read()?)
} }
} }
} }
} }
}
_ => unimplemented!(), _ => unimplemented!(),
}; };
@ -437,6 +466,7 @@ fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool)
let enum_name = Lit::Str(LitStr::new(&struct_name.to_string(), struct_name.span())); let enum_name = Lit::Str(LitStr::new(&struct_name.to_string(), struct_name.span()));
quote_spanned! {span=> quote_spanned! {span=>
#align;
let discriminant:#repr = stream.read_int(#discriminant_bits as usize)?; let discriminant:#repr = stream.read_int(#discriminant_bits as usize)?;
Ok(match discriminant { Ok(match discriminant {
#(#match_arms)* #(#match_arms)*
@ -453,6 +483,12 @@ fn parse(data: Data, struct_name: &Ident, attrs: &[Attribute], unchecked: bool)
fn size(data: Data, struct_name: &Ident, attrs: &[Attribute], has_input_size: bool) -> TokenStream { fn size(data: Data, struct_name: &Ident, attrs: &[Attribute], has_input_size: bool) -> TokenStream {
let span = struct_name.span(); let span = struct_name.span();
if contains_attribute(attrs, &["align"]) {
return quote_spanned! { span =>
None
};
}
match data { match data {
Data::Struct(DataStruct { fields, .. }) => { Data::Struct(DataStruct { fields, .. }) => {
let sizes = fields.iter().map(|f| { let sizes = fields.iter().map(|f| {
@ -503,6 +539,7 @@ fn size(data: Data, struct_name: &Ident, attrs: &[Attribute], has_input_size: bo
} }
}; };
// Unit variants having "align" attributes are not allowed, so we can just check if all variants are unit
let is_unit = data let is_unit = data
.variants .variants
.iter() .iter()
@ -526,6 +563,9 @@ fn is_const_size(attrs: &[Attribute], has_input_size: bool) -> bool {
if get_attribute_value::<Lit>(attrs, &["size_bits"]).is_some() { if get_attribute_value::<Lit>(attrs, &["size_bits"]).is_some() {
return false; return false;
} }
if contains_attribute(attrs, &["align"]) {
return false;
}
get_attribute_value(attrs, &["size"]) get_attribute_value(attrs, &["size"])
.map(|size_lit| match size_lit { .map(|size_lit| match size_lit {
Lit::Int(_) => true, Lit::Int(_) => true,
@ -571,3 +611,13 @@ fn repr_for_bits(discriminant_bits: u64) -> TokenStream {
quote!(usize) quote!(usize)
} }
} }
fn get_align(attrs: &[Attribute]) -> TokenStream {
if contains_attribute(attrs, &["align"]) {
quote! {
stream.align()?
}
} else {
quote! { () }
}
}

View file

@ -335,3 +335,79 @@ fn test_bit_size_sized() {
Some(8 + 8 * 16 + 1) Some(8 + 8 * 16 + 1)
); );
} }
#[derive(BitRead, PartialEq, Debug)]
#[align]
struct AlignStruct(u8);
#[test]
fn test_align() {
let bytes = vec![0, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let buffer = BitReadBuffer::new(&bytes, BigEndian);
let mut stream = BitReadStream::from(buffer);
stream.read_bool().unwrap();
assert_eq!(AlignStruct(0x80), stream.read().unwrap());
assert_eq!(16, stream.pos());
assert_eq!(None, bit_size_of::<AlignStruct>());
}
#[derive(BitRead, PartialEq, Debug)]
#[align]
struct AlignFieldStruct {
#[size = 1]
foo: u8,
#[align]
bar: u8,
}
#[test]
fn test_align_field() {
let bytes = vec![0, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let buffer = BitReadBuffer::new(&bytes, BigEndian);
let mut stream = BitReadStream::from(buffer);
assert_eq!(
AlignFieldStruct { foo: 0, bar: 0x80 },
stream.read().unwrap()
);
assert_eq!(16, stream.pos());
assert_eq!(None, bit_size_of::<AlignStruct>());
}
#[derive(BitRead, PartialEq, Debug)]
#[discriminant_bits = 4]
#[align]
enum AlignEnum {
Foo,
Bar(u8),
}
#[test]
fn test_align_enum() {
let bytes = vec![0x00, 0x18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let buffer = BitReadBuffer::new(&bytes, BigEndian);
let mut stream = BitReadStream::from(buffer);
stream.read_bool().unwrap();
assert_eq!(AlignEnum::Bar(0x80), stream.read().unwrap());
assert_eq!(20, stream.pos());
assert_eq!(None, bit_size_of::<AlignEnum>());
}
#[derive(BitRead, PartialEq, Debug)]
#[discriminant_bits = 4]
#[align]
enum AlignEnumField {
Foo,
#[align]
Bar(u8),
}
#[test]
fn test_align_enum_field() {
let bytes = vec![0x00, 0x10, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let buffer = BitReadBuffer::new(&bytes, BigEndian);
let mut stream = BitReadStream::from(buffer);
stream.read_bool().unwrap();
assert_eq!(AlignEnumField::Bar(0x80), stream.read().unwrap());
assert_eq!(24, stream.pos());
assert_eq!(None, bit_size_of::<AlignEnum>());
}